xref: /libCEED/backends/cuda-ref/ceed-cuda-ref-vector.c (revision f7c1b517c323a719e436874ff6ee79be51545b0d)
1ff1e7120SSebastian Grimberg // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2ff1e7120SSebastian Grimberg // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3ff1e7120SSebastian Grimberg //
4ff1e7120SSebastian Grimberg // SPDX-License-Identifier: BSD-2-Clause
5ff1e7120SSebastian Grimberg //
6ff1e7120SSebastian Grimberg // This file is part of CEED:  http://github.com/ceed
7ff1e7120SSebastian Grimberg 
8ff1e7120SSebastian Grimberg #include <ceed.h>
9ff1e7120SSebastian Grimberg #include <ceed/backend.h>
10ff1e7120SSebastian Grimberg #include <cublas_v2.h>
11ff1e7120SSebastian Grimberg #include <cuda_runtime.h>
12ff1e7120SSebastian Grimberg #include <math.h>
13ff1e7120SSebastian Grimberg #include <stdbool.h>
14ff1e7120SSebastian Grimberg #include <string.h>
15ff1e7120SSebastian Grimberg 
16ff1e7120SSebastian Grimberg #include "../cuda/ceed-cuda-common.h"
17ff1e7120SSebastian Grimberg #include "ceed-cuda-ref.h"
18ff1e7120SSebastian Grimberg 
19ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
20ff1e7120SSebastian Grimberg // Check if host/device sync is needed
21ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
22ff1e7120SSebastian Grimberg static inline int CeedVectorNeedSync_Cuda(const CeedVector vec, CeedMemType mem_type, bool *need_sync) {
23ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
24ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
25ff1e7120SSebastian Grimberg 
26ff1e7120SSebastian Grimberg   bool has_valid_array = false;
27ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array));
28ff1e7120SSebastian Grimberg   switch (mem_type) {
29ff1e7120SSebastian Grimberg     case CEED_MEM_HOST:
30ff1e7120SSebastian Grimberg       *need_sync = has_valid_array && !impl->h_array;
31ff1e7120SSebastian Grimberg       break;
32ff1e7120SSebastian Grimberg     case CEED_MEM_DEVICE:
33ff1e7120SSebastian Grimberg       *need_sync = has_valid_array && !impl->d_array;
34ff1e7120SSebastian Grimberg       break;
35ff1e7120SSebastian Grimberg   }
36ff1e7120SSebastian Grimberg 
37ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
38ff1e7120SSebastian Grimberg }
39ff1e7120SSebastian Grimberg 
40ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
41ff1e7120SSebastian Grimberg // Sync host to device
42ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
43ff1e7120SSebastian Grimberg static inline int CeedVectorSyncH2D_Cuda(const CeedVector vec) {
44ff1e7120SSebastian Grimberg   Ceed ceed;
45ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
46ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
47ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
48ff1e7120SSebastian Grimberg 
49ff1e7120SSebastian Grimberg   CeedCheck(impl->h_array, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
50ff1e7120SSebastian Grimberg 
51ff1e7120SSebastian Grimberg   CeedSize length;
52ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetLength(vec, &length));
53ff1e7120SSebastian Grimberg   size_t bytes = length * sizeof(CeedScalar);
54ff1e7120SSebastian Grimberg 
55ff1e7120SSebastian Grimberg   if (impl->d_array_borrowed) {
56ff1e7120SSebastian Grimberg     impl->d_array = impl->d_array_borrowed;
57ff1e7120SSebastian Grimberg   } else if (impl->d_array_owned) {
58ff1e7120SSebastian Grimberg     impl->d_array = impl->d_array_owned;
59ff1e7120SSebastian Grimberg   } else {
60ff1e7120SSebastian Grimberg     CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_array_owned, bytes));
61ff1e7120SSebastian Grimberg     impl->d_array = impl->d_array_owned;
62ff1e7120SSebastian Grimberg   }
63ff1e7120SSebastian Grimberg 
64ff1e7120SSebastian Grimberg   CeedCallCuda(ceed, cudaMemcpy(impl->d_array, impl->h_array, bytes, cudaMemcpyHostToDevice));
65ff1e7120SSebastian Grimberg 
66ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
67ff1e7120SSebastian Grimberg }
68ff1e7120SSebastian Grimberg 
69ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
70ff1e7120SSebastian Grimberg // Sync device to host
71ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
72ff1e7120SSebastian Grimberg static inline int CeedVectorSyncD2H_Cuda(const CeedVector vec) {
73ff1e7120SSebastian Grimberg   Ceed ceed;
74ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
75ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
76ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
77ff1e7120SSebastian Grimberg 
78ff1e7120SSebastian Grimberg   CeedCheck(impl->d_array, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
79ff1e7120SSebastian Grimberg 
80ff1e7120SSebastian Grimberg   if (impl->h_array_borrowed) {
81ff1e7120SSebastian Grimberg     impl->h_array = impl->h_array_borrowed;
82ff1e7120SSebastian Grimberg   } else if (impl->h_array_owned) {
83ff1e7120SSebastian Grimberg     impl->h_array = impl->h_array_owned;
84ff1e7120SSebastian Grimberg   } else {
85ff1e7120SSebastian Grimberg     CeedSize length;
86ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorGetLength(vec, &length));
87ff1e7120SSebastian Grimberg     CeedCallBackend(CeedCalloc(length, &impl->h_array_owned));
88ff1e7120SSebastian Grimberg     impl->h_array = impl->h_array_owned;
89ff1e7120SSebastian Grimberg   }
90ff1e7120SSebastian Grimberg 
91ff1e7120SSebastian Grimberg   CeedSize length;
92ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetLength(vec, &length));
93ff1e7120SSebastian Grimberg   size_t bytes = length * sizeof(CeedScalar);
94ff1e7120SSebastian Grimberg   CeedCallCuda(ceed, cudaMemcpy(impl->h_array, impl->d_array, bytes, cudaMemcpyDeviceToHost));
95ff1e7120SSebastian Grimberg 
96ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
97ff1e7120SSebastian Grimberg }
98ff1e7120SSebastian Grimberg 
99ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
100ff1e7120SSebastian Grimberg // Sync arrays
101ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
102ff1e7120SSebastian Grimberg static int CeedVectorSyncArray_Cuda(const CeedVector vec, CeedMemType mem_type) {
103ff1e7120SSebastian Grimberg   // Check whether device/host sync is needed
104ff1e7120SSebastian Grimberg   bool need_sync = false;
105ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorNeedSync_Cuda(vec, mem_type, &need_sync));
106ff1e7120SSebastian Grimberg   if (!need_sync) return CEED_ERROR_SUCCESS;
107ff1e7120SSebastian Grimberg 
108ff1e7120SSebastian Grimberg   switch (mem_type) {
109ff1e7120SSebastian Grimberg     case CEED_MEM_HOST:
110ff1e7120SSebastian Grimberg       return CeedVectorSyncD2H_Cuda(vec);
111ff1e7120SSebastian Grimberg     case CEED_MEM_DEVICE:
112ff1e7120SSebastian Grimberg       return CeedVectorSyncH2D_Cuda(vec);
113ff1e7120SSebastian Grimberg   }
114ff1e7120SSebastian Grimberg   return CEED_ERROR_UNSUPPORTED;
115ff1e7120SSebastian Grimberg }
116ff1e7120SSebastian Grimberg 
117ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
118ff1e7120SSebastian Grimberg // Set all pointers as invalid
119ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
120ff1e7120SSebastian Grimberg static inline int CeedVectorSetAllInvalid_Cuda(const CeedVector vec) {
121ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
122ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
123ff1e7120SSebastian Grimberg 
124ff1e7120SSebastian Grimberg   impl->h_array = NULL;
125ff1e7120SSebastian Grimberg   impl->d_array = NULL;
126ff1e7120SSebastian Grimberg 
127ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
128ff1e7120SSebastian Grimberg }
129ff1e7120SSebastian Grimberg 
130ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
131ff1e7120SSebastian Grimberg // Check if CeedVector has any valid pointer
132ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
133ff1e7120SSebastian Grimberg static inline int CeedVectorHasValidArray_Cuda(const CeedVector vec, bool *has_valid_array) {
134ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
135ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
136ff1e7120SSebastian Grimberg 
137ff1e7120SSebastian Grimberg   *has_valid_array = !!impl->h_array || !!impl->d_array;
138ff1e7120SSebastian Grimberg 
139ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
140ff1e7120SSebastian Grimberg }
141ff1e7120SSebastian Grimberg 
142ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
143ff1e7120SSebastian Grimberg // Check if has array of given type
144ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
145ff1e7120SSebastian Grimberg static inline int CeedVectorHasArrayOfType_Cuda(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) {
146ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
147ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
148ff1e7120SSebastian Grimberg 
149ff1e7120SSebastian Grimberg   switch (mem_type) {
150ff1e7120SSebastian Grimberg     case CEED_MEM_HOST:
151ff1e7120SSebastian Grimberg       *has_array_of_type = !!impl->h_array_borrowed || !!impl->h_array_owned;
152ff1e7120SSebastian Grimberg       break;
153ff1e7120SSebastian Grimberg     case CEED_MEM_DEVICE:
154ff1e7120SSebastian Grimberg       *has_array_of_type = !!impl->d_array_borrowed || !!impl->d_array_owned;
155ff1e7120SSebastian Grimberg       break;
156ff1e7120SSebastian Grimberg   }
157ff1e7120SSebastian Grimberg 
158ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
159ff1e7120SSebastian Grimberg }
160ff1e7120SSebastian Grimberg 
161ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
162ff1e7120SSebastian Grimberg // Check if has borrowed array of given type
163ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
164ff1e7120SSebastian Grimberg static inline int CeedVectorHasBorrowedArrayOfType_Cuda(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) {
165ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
166ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
167ff1e7120SSebastian Grimberg 
168ff1e7120SSebastian Grimberg   switch (mem_type) {
169ff1e7120SSebastian Grimberg     case CEED_MEM_HOST:
170ff1e7120SSebastian Grimberg       *has_borrowed_array_of_type = !!impl->h_array_borrowed;
171ff1e7120SSebastian Grimberg       break;
172ff1e7120SSebastian Grimberg     case CEED_MEM_DEVICE:
173ff1e7120SSebastian Grimberg       *has_borrowed_array_of_type = !!impl->d_array_borrowed;
174ff1e7120SSebastian Grimberg       break;
175ff1e7120SSebastian Grimberg   }
176ff1e7120SSebastian Grimberg 
177ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
178ff1e7120SSebastian Grimberg }
179ff1e7120SSebastian Grimberg 
180ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
181ff1e7120SSebastian Grimberg // Set array from host
182ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
183ff1e7120SSebastian Grimberg static int CeedVectorSetArrayHost_Cuda(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
184ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
185ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
186ff1e7120SSebastian Grimberg 
187ff1e7120SSebastian Grimberg   switch (copy_mode) {
188ff1e7120SSebastian Grimberg     case CEED_COPY_VALUES: {
189ff1e7120SSebastian Grimberg       CeedSize length;
190ff1e7120SSebastian Grimberg       if (!impl->h_array_owned) {
191ff1e7120SSebastian Grimberg         CeedCallBackend(CeedVectorGetLength(vec, &length));
192ff1e7120SSebastian Grimberg         CeedCallBackend(CeedMalloc(length, &impl->h_array_owned));
193ff1e7120SSebastian Grimberg       }
194ff1e7120SSebastian Grimberg       impl->h_array_borrowed = NULL;
195ff1e7120SSebastian Grimberg       impl->h_array          = impl->h_array_owned;
196ff1e7120SSebastian Grimberg       if (array) {
197ff1e7120SSebastian Grimberg         CeedSize length;
198ff1e7120SSebastian Grimberg         CeedCallBackend(CeedVectorGetLength(vec, &length));
199ff1e7120SSebastian Grimberg         size_t bytes = length * sizeof(CeedScalar);
200ff1e7120SSebastian Grimberg         memcpy(impl->h_array, array, bytes);
201ff1e7120SSebastian Grimberg       }
202ff1e7120SSebastian Grimberg     } break;
203ff1e7120SSebastian Grimberg     case CEED_OWN_POINTER:
204ff1e7120SSebastian Grimberg       CeedCallBackend(CeedFree(&impl->h_array_owned));
205ff1e7120SSebastian Grimberg       impl->h_array_owned    = array;
206ff1e7120SSebastian Grimberg       impl->h_array_borrowed = NULL;
207ff1e7120SSebastian Grimberg       impl->h_array          = array;
208ff1e7120SSebastian Grimberg       break;
209ff1e7120SSebastian Grimberg     case CEED_USE_POINTER:
210ff1e7120SSebastian Grimberg       CeedCallBackend(CeedFree(&impl->h_array_owned));
211ff1e7120SSebastian Grimberg       impl->h_array_borrowed = array;
212ff1e7120SSebastian Grimberg       impl->h_array          = array;
213ff1e7120SSebastian Grimberg       break;
214ff1e7120SSebastian Grimberg   }
215ff1e7120SSebastian Grimberg 
216ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
217ff1e7120SSebastian Grimberg }
218ff1e7120SSebastian Grimberg 
219ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
220ff1e7120SSebastian Grimberg // Set array from device
221ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
222ff1e7120SSebastian Grimberg static int CeedVectorSetArrayDevice_Cuda(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
223ff1e7120SSebastian Grimberg   Ceed ceed;
224ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
225ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
226ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
227ff1e7120SSebastian Grimberg 
228ff1e7120SSebastian Grimberg   switch (copy_mode) {
229ff1e7120SSebastian Grimberg     case CEED_COPY_VALUES: {
230ff1e7120SSebastian Grimberg       CeedSize length;
231ff1e7120SSebastian Grimberg       CeedCallBackend(CeedVectorGetLength(vec, &length));
232ff1e7120SSebastian Grimberg       size_t bytes = length * sizeof(CeedScalar);
233ff1e7120SSebastian Grimberg       if (!impl->d_array_owned) {
234ff1e7120SSebastian Grimberg         CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_array_owned, bytes));
235ff1e7120SSebastian Grimberg         impl->d_array = impl->d_array_owned;
236ff1e7120SSebastian Grimberg       }
237ff1e7120SSebastian Grimberg       if (array) CeedCallCuda(ceed, cudaMemcpy(impl->d_array, array, bytes, cudaMemcpyDeviceToDevice));
238ff1e7120SSebastian Grimberg     } break;
239ff1e7120SSebastian Grimberg     case CEED_OWN_POINTER:
240ff1e7120SSebastian Grimberg       CeedCallCuda(ceed, cudaFree(impl->d_array_owned));
241ff1e7120SSebastian Grimberg       impl->d_array_owned    = array;
242ff1e7120SSebastian Grimberg       impl->d_array_borrowed = NULL;
243ff1e7120SSebastian Grimberg       impl->d_array          = array;
244ff1e7120SSebastian Grimberg       break;
245ff1e7120SSebastian Grimberg     case CEED_USE_POINTER:
246ff1e7120SSebastian Grimberg       CeedCallCuda(ceed, cudaFree(impl->d_array_owned));
247ff1e7120SSebastian Grimberg       impl->d_array_owned    = NULL;
248ff1e7120SSebastian Grimberg       impl->d_array_borrowed = array;
249ff1e7120SSebastian Grimberg       impl->d_array          = array;
250ff1e7120SSebastian Grimberg       break;
251ff1e7120SSebastian Grimberg   }
252ff1e7120SSebastian Grimberg 
253ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
254ff1e7120SSebastian Grimberg }
255ff1e7120SSebastian Grimberg 
256ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
257ff1e7120SSebastian Grimberg // Set the array used by a vector,
258ff1e7120SSebastian Grimberg //   freeing any previously allocated array if applicable
259ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
260ff1e7120SSebastian Grimberg static int CeedVectorSetArray_Cuda(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) {
261ff1e7120SSebastian Grimberg   Ceed ceed;
262ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
263ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
264ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
265ff1e7120SSebastian Grimberg 
266ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorSetAllInvalid_Cuda(vec));
267ff1e7120SSebastian Grimberg   switch (mem_type) {
268ff1e7120SSebastian Grimberg     case CEED_MEM_HOST:
269ff1e7120SSebastian Grimberg       return CeedVectorSetArrayHost_Cuda(vec, copy_mode, array);
270ff1e7120SSebastian Grimberg     case CEED_MEM_DEVICE:
271ff1e7120SSebastian Grimberg       return CeedVectorSetArrayDevice_Cuda(vec, copy_mode, array);
272ff1e7120SSebastian Grimberg   }
273ff1e7120SSebastian Grimberg 
274ff1e7120SSebastian Grimberg   return CEED_ERROR_UNSUPPORTED;
275ff1e7120SSebastian Grimberg }
276ff1e7120SSebastian Grimberg 
277ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
278ff1e7120SSebastian Grimberg // Set host array to value
279ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
280*f7c1b517Snbeams static int CeedHostSetValue_Cuda(CeedScalar *h_array, CeedSize length, CeedScalar val) {
281*f7c1b517Snbeams   for (CeedSize i = 0; i < length; i++) h_array[i] = val;
282ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
283ff1e7120SSebastian Grimberg }
284ff1e7120SSebastian Grimberg 
285ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
286ff1e7120SSebastian Grimberg // Set device array to value (impl in .cu file)
287ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
288*f7c1b517Snbeams int CeedDeviceSetValue_Cuda(CeedScalar *d_array, CeedSize length, CeedScalar val);
289ff1e7120SSebastian Grimberg 
290ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
291*f7c1b517Snbeams // Set a vector to a value
292ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
293ff1e7120SSebastian Grimberg static int CeedVectorSetValue_Cuda(CeedVector vec, CeedScalar val) {
294ff1e7120SSebastian Grimberg   Ceed ceed;
295ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
296ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
297ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
298ff1e7120SSebastian Grimberg   CeedSize length;
299ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetLength(vec, &length));
300ff1e7120SSebastian Grimberg 
301ff1e7120SSebastian Grimberg   // Set value for synced device/host array
302ff1e7120SSebastian Grimberg   if (!impl->d_array && !impl->h_array) {
303ff1e7120SSebastian Grimberg     if (impl->d_array_borrowed) {
304ff1e7120SSebastian Grimberg       impl->d_array = impl->d_array_borrowed;
305ff1e7120SSebastian Grimberg     } else if (impl->h_array_borrowed) {
306ff1e7120SSebastian Grimberg       impl->h_array = impl->h_array_borrowed;
307ff1e7120SSebastian Grimberg     } else if (impl->d_array_owned) {
308ff1e7120SSebastian Grimberg       impl->d_array = impl->d_array_owned;
309ff1e7120SSebastian Grimberg     } else if (impl->h_array_owned) {
310ff1e7120SSebastian Grimberg       impl->h_array = impl->h_array_owned;
311ff1e7120SSebastian Grimberg     } else {
312ff1e7120SSebastian Grimberg       CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL));
313ff1e7120SSebastian Grimberg     }
314ff1e7120SSebastian Grimberg   }
315ff1e7120SSebastian Grimberg   if (impl->d_array) {
316ff1e7120SSebastian Grimberg     CeedCallBackend(CeedDeviceSetValue_Cuda(impl->d_array, length, val));
317ff1e7120SSebastian Grimberg     impl->h_array = NULL;
318ff1e7120SSebastian Grimberg   }
319ff1e7120SSebastian Grimberg   if (impl->h_array) {
320ff1e7120SSebastian Grimberg     CeedCallBackend(CeedHostSetValue_Cuda(impl->h_array, length, val));
321ff1e7120SSebastian Grimberg     impl->d_array = NULL;
322ff1e7120SSebastian Grimberg   }
323ff1e7120SSebastian Grimberg 
324ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
325ff1e7120SSebastian Grimberg }
326ff1e7120SSebastian Grimberg 
327ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
328ff1e7120SSebastian Grimberg // Vector Take Array
329ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
330ff1e7120SSebastian Grimberg static int CeedVectorTakeArray_Cuda(CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
331ff1e7120SSebastian Grimberg   Ceed ceed;
332ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
333ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
334ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
335ff1e7120SSebastian Grimberg 
336ff1e7120SSebastian Grimberg   // Sync array to requested mem_type
337ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
338ff1e7120SSebastian Grimberg 
339ff1e7120SSebastian Grimberg   // Update pointer
340ff1e7120SSebastian Grimberg   switch (mem_type) {
341ff1e7120SSebastian Grimberg     case CEED_MEM_HOST:
342ff1e7120SSebastian Grimberg       (*array)               = impl->h_array_borrowed;
343ff1e7120SSebastian Grimberg       impl->h_array_borrowed = NULL;
344ff1e7120SSebastian Grimberg       impl->h_array          = NULL;
345ff1e7120SSebastian Grimberg       break;
346ff1e7120SSebastian Grimberg     case CEED_MEM_DEVICE:
347ff1e7120SSebastian Grimberg       (*array)               = impl->d_array_borrowed;
348ff1e7120SSebastian Grimberg       impl->d_array_borrowed = NULL;
349ff1e7120SSebastian Grimberg       impl->d_array          = NULL;
350ff1e7120SSebastian Grimberg       break;
351ff1e7120SSebastian Grimberg   }
352ff1e7120SSebastian Grimberg 
353ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
354ff1e7120SSebastian Grimberg }
355ff1e7120SSebastian Grimberg 
356ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
357ff1e7120SSebastian Grimberg // Core logic for array syncronization for GetArray.
358ff1e7120SSebastian Grimberg //   If a different memory type is most up to date, this will perform a copy
359ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
360ff1e7120SSebastian Grimberg static int CeedVectorGetArrayCore_Cuda(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
361ff1e7120SSebastian Grimberg   Ceed ceed;
362ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
363ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
364ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
365ff1e7120SSebastian Grimberg 
366ff1e7120SSebastian Grimberg   // Sync array to requested mem_type
367ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
368ff1e7120SSebastian Grimberg 
369ff1e7120SSebastian Grimberg   // Update pointer
370ff1e7120SSebastian Grimberg   switch (mem_type) {
371ff1e7120SSebastian Grimberg     case CEED_MEM_HOST:
372ff1e7120SSebastian Grimberg       *array = impl->h_array;
373ff1e7120SSebastian Grimberg       break;
374ff1e7120SSebastian Grimberg     case CEED_MEM_DEVICE:
375ff1e7120SSebastian Grimberg       *array = impl->d_array;
376ff1e7120SSebastian Grimberg       break;
377ff1e7120SSebastian Grimberg   }
378ff1e7120SSebastian Grimberg 
379ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
380ff1e7120SSebastian Grimberg }
381ff1e7120SSebastian Grimberg 
382ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
383ff1e7120SSebastian Grimberg // Get read-only access to a vector via the specified mem_type
384ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
385ff1e7120SSebastian Grimberg static int CeedVectorGetArrayRead_Cuda(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) {
386ff1e7120SSebastian Grimberg   return CeedVectorGetArrayCore_Cuda(vec, mem_type, (CeedScalar **)array);
387ff1e7120SSebastian Grimberg }
388ff1e7120SSebastian Grimberg 
389ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
390ff1e7120SSebastian Grimberg // Get read/write access to a vector via the specified mem_type
391ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
392ff1e7120SSebastian Grimberg static int CeedVectorGetArray_Cuda(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
393ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
394ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
395ff1e7120SSebastian Grimberg 
396ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetArrayCore_Cuda(vec, mem_type, array));
397ff1e7120SSebastian Grimberg 
398ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorSetAllInvalid_Cuda(vec));
399ff1e7120SSebastian Grimberg   switch (mem_type) {
400ff1e7120SSebastian Grimberg     case CEED_MEM_HOST:
401ff1e7120SSebastian Grimberg       impl->h_array = *array;
402ff1e7120SSebastian Grimberg       break;
403ff1e7120SSebastian Grimberg     case CEED_MEM_DEVICE:
404ff1e7120SSebastian Grimberg       impl->d_array = *array;
405ff1e7120SSebastian Grimberg       break;
406ff1e7120SSebastian Grimberg   }
407ff1e7120SSebastian Grimberg 
408ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
409ff1e7120SSebastian Grimberg }
410ff1e7120SSebastian Grimberg 
411ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
412ff1e7120SSebastian Grimberg // Get write access to a vector via the specified mem_type
413ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
414ff1e7120SSebastian Grimberg static int CeedVectorGetArrayWrite_Cuda(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
415ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
416ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
417ff1e7120SSebastian Grimberg 
418ff1e7120SSebastian Grimberg   bool has_array_of_type = true;
419ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorHasArrayOfType_Cuda(vec, mem_type, &has_array_of_type));
420ff1e7120SSebastian Grimberg   if (!has_array_of_type) {
421ff1e7120SSebastian Grimberg     // Allocate if array is not yet allocated
422ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL));
423ff1e7120SSebastian Grimberg   } else {
424ff1e7120SSebastian Grimberg     // Select dirty array
425ff1e7120SSebastian Grimberg     switch (mem_type) {
426ff1e7120SSebastian Grimberg       case CEED_MEM_HOST:
427ff1e7120SSebastian Grimberg         if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed;
428ff1e7120SSebastian Grimberg         else impl->h_array = impl->h_array_owned;
429ff1e7120SSebastian Grimberg         break;
430ff1e7120SSebastian Grimberg       case CEED_MEM_DEVICE:
431ff1e7120SSebastian Grimberg         if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed;
432ff1e7120SSebastian Grimberg         else impl->d_array = impl->d_array_owned;
433ff1e7120SSebastian Grimberg     }
434ff1e7120SSebastian Grimberg   }
435ff1e7120SSebastian Grimberg 
436ff1e7120SSebastian Grimberg   return CeedVectorGetArray_Cuda(vec, mem_type, array);
437ff1e7120SSebastian Grimberg }
438ff1e7120SSebastian Grimberg 
439ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
440ff1e7120SSebastian Grimberg // Get the norm of a CeedVector
441ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
442ff1e7120SSebastian Grimberg static int CeedVectorNorm_Cuda(CeedVector vec, CeedNormType type, CeedScalar *norm) {
443ff1e7120SSebastian Grimberg   Ceed ceed;
444ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
445ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
446ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
447ff1e7120SSebastian Grimberg   CeedSize length;
448ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetLength(vec, &length));
449ff1e7120SSebastian Grimberg   cublasHandle_t handle;
450ff1e7120SSebastian Grimberg   CeedCallBackend(CeedGetCublasHandle_Cuda(ceed, &handle));
451ff1e7120SSebastian Grimberg 
452*f7c1b517Snbeams #if CUDA_VERSION < 12000
453*f7c1b517Snbeams   // With CUDA 12, we can use the 64-bit integer interface. Prior to that,
454*f7c1b517Snbeams   // we need to check if the vector is too long to handle with int32,
455*f7c1b517Snbeams   // and if so, divide it into subsections for repeated cuBLAS calls
456*f7c1b517Snbeams   CeedSize num_calls = length / INT_MAX;
457*f7c1b517Snbeams   if (length % INT_MAX > 0) num_calls += 1;
458*f7c1b517Snbeams #endif
459*f7c1b517Snbeams 
460ff1e7120SSebastian Grimberg   // Compute norm
461ff1e7120SSebastian Grimberg   const CeedScalar *d_array;
462ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array));
463ff1e7120SSebastian Grimberg   switch (type) {
464ff1e7120SSebastian Grimberg     case CEED_NORM_1: {
465ff1e7120SSebastian Grimberg       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
466*f7c1b517Snbeams #if CUDA_VERSION >= 12000  // We have CUDA 12, and can use 64-bit integers
467*f7c1b517Snbeams         CeedCallCublas(ceed, cublasSasum_64(handle, (int64_t)length, (float *)d_array, 1, (float *)norm));
468*f7c1b517Snbeams #else
469*f7c1b517Snbeams         float  sub_norm = 0.0;
470*f7c1b517Snbeams         float *d_array_start;
471*f7c1b517Snbeams         for (CeedInt i = 0; i < num_calls; i++) {
472*f7c1b517Snbeams           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
473*f7c1b517Snbeams           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
474*f7c1b517Snbeams           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
475*f7c1b517Snbeams           CeedCallCublas(ceed, cublasSasum(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
476*f7c1b517Snbeams           *norm += sub_norm;
477*f7c1b517Snbeams         }
478*f7c1b517Snbeams #endif
479ff1e7120SSebastian Grimberg       } else {
480*f7c1b517Snbeams #if CUDA_VERSION >= 12000
481*f7c1b517Snbeams         CeedCallCublas(ceed, cublasDasum_64(handle, (int64_t)length, (double *)d_array, 1, (double *)norm));
482*f7c1b517Snbeams #else
483*f7c1b517Snbeams         double  sub_norm = 0.0;
484*f7c1b517Snbeams         double *d_array_start;
485*f7c1b517Snbeams         for (CeedInt i = 0; i < num_calls; i++) {
486*f7c1b517Snbeams           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
487*f7c1b517Snbeams           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
488*f7c1b517Snbeams           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
489*f7c1b517Snbeams           CeedCallCublas(ceed, cublasDasum(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
490*f7c1b517Snbeams           *norm += sub_norm;
491*f7c1b517Snbeams         }
492*f7c1b517Snbeams #endif
493ff1e7120SSebastian Grimberg       }
494ff1e7120SSebastian Grimberg       break;
495ff1e7120SSebastian Grimberg     }
496ff1e7120SSebastian Grimberg     case CEED_NORM_2: {
497ff1e7120SSebastian Grimberg       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
498*f7c1b517Snbeams #if CUDA_VERSION >= 12000
499*f7c1b517Snbeams         CeedCallCublas(ceed, cublasSnrm2_64(handle, (int64_t)length, (float *)d_array, 1, (float *)norm));
500*f7c1b517Snbeams #else
501*f7c1b517Snbeams         float  sub_norm = 0.0, norm_sum = 0.0;
502*f7c1b517Snbeams         float *d_array_start;
503*f7c1b517Snbeams         for (CeedInt i = 0; i < num_calls; i++) {
504*f7c1b517Snbeams           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
505*f7c1b517Snbeams           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
506*f7c1b517Snbeams           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
507*f7c1b517Snbeams           CeedCallCublas(ceed, cublasSnrm2(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
508*f7c1b517Snbeams           norm_sum += sub_norm * sub_norm;
509*f7c1b517Snbeams         }
510*f7c1b517Snbeams         *norm            = sqrt(norm_sum);
511*f7c1b517Snbeams #endif
512ff1e7120SSebastian Grimberg       } else {
513*f7c1b517Snbeams #if CUDA_VERSION >= 12000
514*f7c1b517Snbeams         CeedCallCublas(ceed, cublasDnrm2_64(handle, (int64_t)length, (double *)d_array, 1, (double *)norm));
515*f7c1b517Snbeams #else
516*f7c1b517Snbeams         double  sub_norm = 0.0, norm_sum = 0.0;
517*f7c1b517Snbeams         double *d_array_start;
518*f7c1b517Snbeams         for (CeedInt i = 0; i < num_calls; i++) {
519*f7c1b517Snbeams           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
520*f7c1b517Snbeams           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
521*f7c1b517Snbeams           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
522*f7c1b517Snbeams           CeedCallCublas(ceed, cublasDnrm2(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
523*f7c1b517Snbeams           norm_sum += sub_norm * sub_norm;
524*f7c1b517Snbeams         }
525*f7c1b517Snbeams         *norm = sqrt(norm_sum);
526*f7c1b517Snbeams #endif
527ff1e7120SSebastian Grimberg       }
528ff1e7120SSebastian Grimberg       break;
529ff1e7120SSebastian Grimberg     }
530ff1e7120SSebastian Grimberg     case CEED_NORM_MAX: {
531ff1e7120SSebastian Grimberg       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
532*f7c1b517Snbeams #if CUDA_VERSION >= 12000
533*f7c1b517Snbeams         int64_t indx;
534*f7c1b517Snbeams         CeedCallCublas(ceed, cublasIsamax_64(handle, (int64_t)length, (float *)d_array, 1, &indx));
535ff1e7120SSebastian Grimberg         CeedScalar normNoAbs;
536ff1e7120SSebastian Grimberg         CeedCallCuda(ceed, cudaMemcpy(&normNoAbs, impl->d_array + indx - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
537ff1e7120SSebastian Grimberg         *norm = fabs(normNoAbs);
538*f7c1b517Snbeams #else
539*f7c1b517Snbeams         CeedInt indx;
540*f7c1b517Snbeams         float   sub_max = 0.0, current_max = 0.0;
541*f7c1b517Snbeams         float  *d_array_start;
542*f7c1b517Snbeams         for (CeedInt i = 0; i < num_calls; i++) {
543*f7c1b517Snbeams           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
544*f7c1b517Snbeams           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
545*f7c1b517Snbeams           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
546*f7c1b517Snbeams           CeedCallCublas(ceed, cublasIsamax(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &indx));
547*f7c1b517Snbeams           CeedCallCuda(ceed, cudaMemcpy(&sub_max, d_array_start + indx - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
548*f7c1b517Snbeams           if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
549*f7c1b517Snbeams         }
550*f7c1b517Snbeams         *norm = current_max;
551*f7c1b517Snbeams #endif
552*f7c1b517Snbeams       } else {
553*f7c1b517Snbeams #if CUDA_VERSION >= 12000
554*f7c1b517Snbeams         int64_t indx;
555*f7c1b517Snbeams         CeedCallCublas(ceed, cublasIdamax_64(handle, (int64_t)length, (double *)d_array, 1, &indx));
556*f7c1b517Snbeams         CeedScalar normNoAbs;
557*f7c1b517Snbeams         CeedCallCuda(ceed, cudaMemcpy(&normNoAbs, impl->d_array + indx - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
558*f7c1b517Snbeams         *norm = fabs(normNoAbs);
559*f7c1b517Snbeams #else
560*f7c1b517Snbeams         CeedInt indx;
561*f7c1b517Snbeams         double  sub_max = 0.0, current_max = 0.0;
562*f7c1b517Snbeams         double *d_array_start;
563*f7c1b517Snbeams         for (CeedInt i = 0; i < num_calls; i++) {
564*f7c1b517Snbeams           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
565*f7c1b517Snbeams           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
566*f7c1b517Snbeams           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
567*f7c1b517Snbeams           CeedCallCublas(ceed, cublasIdamax(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &indx));
568*f7c1b517Snbeams           CeedCallCuda(ceed, cudaMemcpy(&sub_max, d_array_start + indx - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
569*f7c1b517Snbeams           if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
570*f7c1b517Snbeams         }
571*f7c1b517Snbeams         *norm = current_max;
572*f7c1b517Snbeams #endif
573*f7c1b517Snbeams       }
574ff1e7120SSebastian Grimberg       break;
575ff1e7120SSebastian Grimberg     }
576ff1e7120SSebastian Grimberg   }
577ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array));
578ff1e7120SSebastian Grimberg 
579ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
580ff1e7120SSebastian Grimberg }
581ff1e7120SSebastian Grimberg 
582ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
583ff1e7120SSebastian Grimberg // Take reciprocal of a vector on host
584ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
585*f7c1b517Snbeams static int CeedHostReciprocal_Cuda(CeedScalar *h_array, CeedSize length) {
586*f7c1b517Snbeams   for (CeedSize i = 0; i < length; i++) {
587ff1e7120SSebastian Grimberg     if (fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i];
588ff1e7120SSebastian Grimberg   }
589ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
590ff1e7120SSebastian Grimberg }
591ff1e7120SSebastian Grimberg 
592ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
593ff1e7120SSebastian Grimberg // Take reciprocal of a vector on device (impl in .cu file)
594ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
595*f7c1b517Snbeams int CeedDeviceReciprocal_Cuda(CeedScalar *d_array, CeedSize length);
596ff1e7120SSebastian Grimberg 
597ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
598ff1e7120SSebastian Grimberg // Take reciprocal of a vector
599ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
600ff1e7120SSebastian Grimberg static int CeedVectorReciprocal_Cuda(CeedVector vec) {
601ff1e7120SSebastian Grimberg   Ceed ceed;
602ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
603ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
604ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
605ff1e7120SSebastian Grimberg   CeedSize length;
606ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetLength(vec, &length));
607ff1e7120SSebastian Grimberg 
608ff1e7120SSebastian Grimberg   // Set value for synced device/host array
609ff1e7120SSebastian Grimberg   if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Cuda(impl->d_array, length));
610ff1e7120SSebastian Grimberg   if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Cuda(impl->h_array, length));
611ff1e7120SSebastian Grimberg 
612ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
613ff1e7120SSebastian Grimberg }
614ff1e7120SSebastian Grimberg 
615ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
616ff1e7120SSebastian Grimberg // Compute x = alpha x on the host
617ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
618*f7c1b517Snbeams static int CeedHostScale_Cuda(CeedScalar *x_array, CeedScalar alpha, CeedSize length) {
619*f7c1b517Snbeams   for (CeedSize i = 0; i < length; i++) x_array[i] *= alpha;
620ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
621ff1e7120SSebastian Grimberg }
622ff1e7120SSebastian Grimberg 
623ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
624ff1e7120SSebastian Grimberg // Compute x = alpha x on device (impl in .cu file)
625ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
626*f7c1b517Snbeams int CeedDeviceScale_Cuda(CeedScalar *x_array, CeedScalar alpha, CeedSize length);
627ff1e7120SSebastian Grimberg 
628ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
629ff1e7120SSebastian Grimberg // Compute x = alpha x
630ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
631ff1e7120SSebastian Grimberg static int CeedVectorScale_Cuda(CeedVector x, CeedScalar alpha) {
632ff1e7120SSebastian Grimberg   Ceed ceed;
633ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetCeed(x, &ceed));
634ff1e7120SSebastian Grimberg   CeedVector_Cuda *x_impl;
635ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(x, &x_impl));
636ff1e7120SSebastian Grimberg   CeedSize length;
637ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetLength(x, &length));
638ff1e7120SSebastian Grimberg 
639ff1e7120SSebastian Grimberg   // Set value for synced device/host array
640ff1e7120SSebastian Grimberg   if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Cuda(x_impl->d_array, alpha, length));
641ff1e7120SSebastian Grimberg   if (x_impl->h_array) CeedCallBackend(CeedHostScale_Cuda(x_impl->h_array, alpha, length));
642ff1e7120SSebastian Grimberg 
643ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
644ff1e7120SSebastian Grimberg }
645ff1e7120SSebastian Grimberg 
646ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
647ff1e7120SSebastian Grimberg // Compute y = alpha x + y on the host
648ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
649*f7c1b517Snbeams static int CeedHostAXPY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) {
650*f7c1b517Snbeams   for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i];
651ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
652ff1e7120SSebastian Grimberg }
653ff1e7120SSebastian Grimberg 
654ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
655ff1e7120SSebastian Grimberg // Compute y = alpha x + y on device (impl in .cu file)
656ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
657*f7c1b517Snbeams int CeedDeviceAXPY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length);
658ff1e7120SSebastian Grimberg 
659ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
660ff1e7120SSebastian Grimberg // Compute y = alpha x + y
661ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
662ff1e7120SSebastian Grimberg static int CeedVectorAXPY_Cuda(CeedVector y, CeedScalar alpha, CeedVector x) {
663ff1e7120SSebastian Grimberg   Ceed ceed;
664ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetCeed(y, &ceed));
665ff1e7120SSebastian Grimberg   CeedVector_Cuda *y_impl, *x_impl;
666ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(y, &y_impl));
667ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(x, &x_impl));
668ff1e7120SSebastian Grimberg   CeedSize length;
669ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetLength(y, &length));
670ff1e7120SSebastian Grimberg 
671ff1e7120SSebastian Grimberg   // Set value for synced device/host array
672ff1e7120SSebastian Grimberg   if (y_impl->d_array) {
673ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
674ff1e7120SSebastian Grimberg     CeedCallBackend(CeedDeviceAXPY_Cuda(y_impl->d_array, alpha, x_impl->d_array, length));
675ff1e7120SSebastian Grimberg   }
676ff1e7120SSebastian Grimberg   if (y_impl->h_array) {
677ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
678ff1e7120SSebastian Grimberg     CeedCallBackend(CeedHostAXPY_Cuda(y_impl->h_array, alpha, x_impl->h_array, length));
679ff1e7120SSebastian Grimberg   }
680ff1e7120SSebastian Grimberg 
681ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
682ff1e7120SSebastian Grimberg }
683ff1e7120SSebastian Grimberg 
684ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
685ff1e7120SSebastian Grimberg // Compute y = alpha x + beta y on the host
686ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
687*f7c1b517Snbeams static int CeedHostAXPBY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length) {
688*f7c1b517Snbeams   for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i] + beta * y_array[i];
689ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
690ff1e7120SSebastian Grimberg }
691ff1e7120SSebastian Grimberg 
692ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
693ff1e7120SSebastian Grimberg // Compute y = alpha x + beta y on device (impl in .cu file)
694ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
695*f7c1b517Snbeams int CeedDeviceAXPBY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length);
696ff1e7120SSebastian Grimberg 
697ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
698ff1e7120SSebastian Grimberg // Compute y = alpha x + beta y
699ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
700ff1e7120SSebastian Grimberg static int CeedVectorAXPBY_Cuda(CeedVector y, CeedScalar alpha, CeedScalar beta, CeedVector x) {
701ff1e7120SSebastian Grimberg   Ceed ceed;
702ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetCeed(y, &ceed));
703ff1e7120SSebastian Grimberg   CeedVector_Cuda *y_impl, *x_impl;
704ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(y, &y_impl));
705ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(x, &x_impl));
706ff1e7120SSebastian Grimberg   CeedSize length;
707ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetLength(y, &length));
708ff1e7120SSebastian Grimberg 
709ff1e7120SSebastian Grimberg   // Set value for synced device/host array
710ff1e7120SSebastian Grimberg   if (y_impl->d_array) {
711ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
712ff1e7120SSebastian Grimberg     CeedCallBackend(CeedDeviceAXPBY_Cuda(y_impl->d_array, alpha, beta, x_impl->d_array, length));
713ff1e7120SSebastian Grimberg   }
714ff1e7120SSebastian Grimberg   if (y_impl->h_array) {
715ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
716ff1e7120SSebastian Grimberg     CeedCallBackend(CeedHostAXPBY_Cuda(y_impl->h_array, alpha, beta, x_impl->h_array, length));
717ff1e7120SSebastian Grimberg   }
718ff1e7120SSebastian Grimberg 
719ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
720ff1e7120SSebastian Grimberg }
721ff1e7120SSebastian Grimberg 
722ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
723ff1e7120SSebastian Grimberg // Compute the pointwise multiplication w = x .* y on the host
724ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
725*f7c1b517Snbeams static int CeedHostPointwiseMult_Cuda(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) {
726*f7c1b517Snbeams   for (CeedSize i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i];
727ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
728ff1e7120SSebastian Grimberg }
729ff1e7120SSebastian Grimberg 
730ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
731ff1e7120SSebastian Grimberg // Compute the pointwise multiplication w = x .* y on device (impl in .cu file)
732ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
733*f7c1b517Snbeams int CeedDevicePointwiseMult_Cuda(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length);
734ff1e7120SSebastian Grimberg 
735ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
736ff1e7120SSebastian Grimberg // Compute the pointwise multiplication w = x .* y
737ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
738ff1e7120SSebastian Grimberg static int CeedVectorPointwiseMult_Cuda(CeedVector w, CeedVector x, CeedVector y) {
739ff1e7120SSebastian Grimberg   Ceed ceed;
740ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetCeed(w, &ceed));
741ff1e7120SSebastian Grimberg   CeedVector_Cuda *w_impl, *x_impl, *y_impl;
742ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(w, &w_impl));
743ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(x, &x_impl));
744ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(y, &y_impl));
745ff1e7120SSebastian Grimberg   CeedSize length;
746ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetLength(w, &length));
747ff1e7120SSebastian Grimberg 
748ff1e7120SSebastian Grimberg   // Set value for synced device/host array
749ff1e7120SSebastian Grimberg   if (!w_impl->d_array && !w_impl->h_array) {
750ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorSetValue(w, 0.0));
751ff1e7120SSebastian Grimberg   }
752ff1e7120SSebastian Grimberg   if (w_impl->d_array) {
753ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
754ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE));
755ff1e7120SSebastian Grimberg     CeedCallBackend(CeedDevicePointwiseMult_Cuda(w_impl->d_array, x_impl->d_array, y_impl->d_array, length));
756ff1e7120SSebastian Grimberg   }
757ff1e7120SSebastian Grimberg   if (w_impl->h_array) {
758ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
759ff1e7120SSebastian Grimberg     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST));
760ff1e7120SSebastian Grimberg     CeedCallBackend(CeedHostPointwiseMult_Cuda(w_impl->h_array, x_impl->h_array, y_impl->h_array, length));
761ff1e7120SSebastian Grimberg   }
762ff1e7120SSebastian Grimberg 
763ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
764ff1e7120SSebastian Grimberg }
765ff1e7120SSebastian Grimberg 
766ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
767ff1e7120SSebastian Grimberg // Destroy the vector
768ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
769ff1e7120SSebastian Grimberg static int CeedVectorDestroy_Cuda(const CeedVector vec) {
770ff1e7120SSebastian Grimberg   Ceed ceed;
771ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
772ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
773ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetData(vec, &impl));
774ff1e7120SSebastian Grimberg 
775ff1e7120SSebastian Grimberg   CeedCallCuda(ceed, cudaFree(impl->d_array_owned));
776ff1e7120SSebastian Grimberg   CeedCallBackend(CeedFree(&impl->h_array_owned));
777ff1e7120SSebastian Grimberg   CeedCallBackend(CeedFree(&impl));
778ff1e7120SSebastian Grimberg 
779ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
780ff1e7120SSebastian Grimberg }
781ff1e7120SSebastian Grimberg 
782ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
783ff1e7120SSebastian Grimberg // Create a vector of the specified length (does not allocate memory)
784ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
785ff1e7120SSebastian Grimberg int CeedVectorCreate_Cuda(CeedSize n, CeedVector vec) {
786ff1e7120SSebastian Grimberg   CeedVector_Cuda *impl;
787ff1e7120SSebastian Grimberg   Ceed             ceed;
788ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
789ff1e7120SSebastian Grimberg 
790ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Cuda));
791ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Cuda));
792ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Cuda));
793ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Cuda));
794ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", (int (*)())CeedVectorSetValue_Cuda));
795ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Cuda));
796ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Cuda));
797ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Cuda));
798ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Cuda));
799ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Norm", CeedVectorNorm_Cuda));
800ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Cuda));
801ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Scale", (int (*)())CeedVectorScale_Cuda));
802ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPY", (int (*)())CeedVectorAXPY_Cuda));
803ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPBY", (int (*)())CeedVectorAXPBY_Cuda));
804ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Cuda));
805ff1e7120SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Cuda));
806ff1e7120SSebastian Grimberg 
807ff1e7120SSebastian Grimberg   CeedCallBackend(CeedCalloc(1, &impl));
808ff1e7120SSebastian Grimberg   CeedCallBackend(CeedVectorSetData(vec, impl));
809ff1e7120SSebastian Grimberg 
810ff1e7120SSebastian Grimberg   return CEED_ERROR_SUCCESS;
811ff1e7120SSebastian Grimberg }
812ff1e7120SSebastian Grimberg 
813ff1e7120SSebastian Grimberg //------------------------------------------------------------------------------
814