xref: /libCEED/backends/cuda-ref/ceed-cuda-ref-vector.c (revision e6cb4fca5f98bfb30d17a8599499ec279b90b048)
1 // Copyright (c) 2017-2025, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <cuda_runtime.h>
11 #include <math.h>
12 #include <stdbool.h>
13 #include <string.h>
14 
15 #include "../cuda/ceed-cuda-common.h"
16 #include "ceed-cuda-ref.h"
17 
18 //------------------------------------------------------------------------------
19 // Check if host/device sync is needed
20 //------------------------------------------------------------------------------
21 static inline int CeedVectorNeedSync_Cuda(const CeedVector vec, CeedMemType mem_type, bool *need_sync) {
22   bool             has_valid_array = false;
23   CeedVector_Cuda *impl;
24 
25   CeedCallBackend(CeedVectorGetData(vec, &impl));
26   CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array));
27   switch (mem_type) {
28     case CEED_MEM_HOST:
29       *need_sync = has_valid_array && !impl->h_array;
30       break;
31     case CEED_MEM_DEVICE:
32       *need_sync = has_valid_array && !impl->d_array;
33       break;
34   }
35   return CEED_ERROR_SUCCESS;
36 }
37 
38 //------------------------------------------------------------------------------
39 // Sync host to device
40 //------------------------------------------------------------------------------
41 static inline int CeedVectorSyncH2D_Cuda(const CeedVector vec) {
42   CeedSize         length;
43   size_t           bytes;
44   CeedVector_Cuda *impl;
45 
46   CeedCallBackend(CeedVectorGetData(vec, &impl));
47 
48   CeedCheck(impl->h_array, CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND, "No valid host data to sync to device");
49 
50   CeedCallBackend(CeedVectorGetLength(vec, &length));
51   bytes = length * sizeof(CeedScalar);
52   if (impl->d_array_borrowed) {
53     impl->d_array = impl->d_array_borrowed;
54   } else if (impl->d_array_owned) {
55     impl->d_array = impl->d_array_owned;
56   } else {
57     CeedCallCuda(CeedVectorReturnCeed(vec), cudaMalloc((void **)&impl->d_array_owned, bytes));
58     impl->d_array = impl->d_array_owned;
59   }
60   CeedCallCuda(CeedVectorReturnCeed(vec), cudaMemcpy(impl->d_array, impl->h_array, bytes, cudaMemcpyHostToDevice));
61   return CEED_ERROR_SUCCESS;
62 }
63 
64 //------------------------------------------------------------------------------
65 // Sync device to host
66 //------------------------------------------------------------------------------
67 static inline int CeedVectorSyncD2H_Cuda(const CeedVector vec) {
68   CeedSize         length;
69   CeedVector_Cuda *impl;
70 
71   CeedCallBackend(CeedVectorGetData(vec, &impl));
72 
73   CeedCheck(impl->d_array, CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND, "No valid device data to sync to host");
74 
75   if (impl->h_array_borrowed) {
76     impl->h_array = impl->h_array_borrowed;
77   } else if (impl->h_array_owned) {
78     impl->h_array = impl->h_array_owned;
79   } else {
80     CeedSize length;
81 
82     CeedCallBackend(CeedVectorGetLength(vec, &length));
83     CeedCallBackend(CeedCalloc(length, &impl->h_array_owned));
84     impl->h_array = impl->h_array_owned;
85   }
86 
87   CeedCallBackend(CeedVectorGetLength(vec, &length));
88   size_t bytes = length * sizeof(CeedScalar);
89 
90   CeedCallCuda(CeedVectorReturnCeed(vec), cudaMemcpy(impl->h_array, impl->d_array, bytes, cudaMemcpyDeviceToHost));
91   return CEED_ERROR_SUCCESS;
92 }
93 
94 //------------------------------------------------------------------------------
95 // Sync arrays
96 //------------------------------------------------------------------------------
97 static int CeedVectorSyncArray_Cuda(const CeedVector vec, CeedMemType mem_type) {
98   bool need_sync = false;
99 
100   // Check whether device/host sync is needed
101   CeedCallBackend(CeedVectorNeedSync_Cuda(vec, mem_type, &need_sync));
102   if (!need_sync) return CEED_ERROR_SUCCESS;
103 
104   switch (mem_type) {
105     case CEED_MEM_HOST:
106       return CeedVectorSyncD2H_Cuda(vec);
107     case CEED_MEM_DEVICE:
108       return CeedVectorSyncH2D_Cuda(vec);
109   }
110   return CEED_ERROR_UNSUPPORTED;
111 }
112 
113 //------------------------------------------------------------------------------
114 // Set all pointers as invalid
115 //------------------------------------------------------------------------------
116 static inline int CeedVectorSetAllInvalid_Cuda(const CeedVector vec) {
117   CeedVector_Cuda *impl;
118 
119   CeedCallBackend(CeedVectorGetData(vec, &impl));
120   impl->h_array = NULL;
121   impl->d_array = NULL;
122   return CEED_ERROR_SUCCESS;
123 }
124 
125 //------------------------------------------------------------------------------
126 // Check if CeedVector has any valid pointer
127 //------------------------------------------------------------------------------
128 static inline int CeedVectorHasValidArray_Cuda(const CeedVector vec, bool *has_valid_array) {
129   CeedVector_Cuda *impl;
130 
131   CeedCallBackend(CeedVectorGetData(vec, &impl));
132   *has_valid_array = impl->h_array || impl->d_array;
133   return CEED_ERROR_SUCCESS;
134 }
135 
136 //------------------------------------------------------------------------------
137 // Check if has array of given type
138 //------------------------------------------------------------------------------
139 static inline int CeedVectorHasArrayOfType_Cuda(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) {
140   CeedVector_Cuda *impl;
141 
142   CeedCallBackend(CeedVectorGetData(vec, &impl));
143   switch (mem_type) {
144     case CEED_MEM_HOST:
145       *has_array_of_type = impl->h_array_borrowed || impl->h_array_owned;
146       break;
147     case CEED_MEM_DEVICE:
148       *has_array_of_type = impl->d_array_borrowed || impl->d_array_owned;
149       break;
150   }
151   return CEED_ERROR_SUCCESS;
152 }
153 
154 //------------------------------------------------------------------------------
155 // Check if has borrowed array of given type
156 //------------------------------------------------------------------------------
157 static inline int CeedVectorHasBorrowedArrayOfType_Cuda(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) {
158   CeedVector_Cuda *impl;
159 
160   CeedCallBackend(CeedVectorGetData(vec, &impl));
161   switch (mem_type) {
162     case CEED_MEM_HOST:
163       *has_borrowed_array_of_type = impl->h_array_borrowed;
164       break;
165     case CEED_MEM_DEVICE:
166       *has_borrowed_array_of_type = impl->d_array_borrowed;
167       break;
168   }
169   return CEED_ERROR_SUCCESS;
170 }
171 
172 //------------------------------------------------------------------------------
173 // Set array from host
174 //------------------------------------------------------------------------------
175 static int CeedVectorSetArrayHost_Cuda(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
176   CeedSize         length;
177   CeedVector_Cuda *impl;
178 
179   CeedCallBackend(CeedVectorGetData(vec, &impl));
180   CeedCallBackend(CeedVectorGetLength(vec, &length));
181 
182   CeedCallBackend(CeedSetHostCeedScalarArray(array, copy_mode, length, (const CeedScalar **)&impl->h_array_owned,
183                                              (const CeedScalar **)&impl->h_array_borrowed, (const CeedScalar **)&impl->h_array));
184   return CEED_ERROR_SUCCESS;
185 }
186 
187 //------------------------------------------------------------------------------
188 // Set array from device
189 //------------------------------------------------------------------------------
190 static int CeedVectorSetArrayDevice_Cuda(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
191   CeedSize         length;
192   Ceed             ceed;
193   CeedVector_Cuda *impl;
194 
195   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
196   CeedCallBackend(CeedVectorGetData(vec, &impl));
197   CeedCallBackend(CeedVectorGetLength(vec, &length));
198 
199   CeedCallBackend(CeedSetDeviceCeedScalarArray_Cuda(ceed, array, copy_mode, length, (const CeedScalar **)&impl->d_array_owned,
200                                                     (const CeedScalar **)&impl->d_array_borrowed, (const CeedScalar **)&impl->d_array));
201   CeedCallBackend(CeedDestroy(&ceed));
202   return CEED_ERROR_SUCCESS;
203 }
204 
205 //------------------------------------------------------------------------------
206 // Set the array used by a vector,
207 //   freeing any previously allocated array if applicable
208 //------------------------------------------------------------------------------
209 static int CeedVectorSetArray_Cuda(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) {
210   CeedVector_Cuda *impl;
211 
212   CeedCallBackend(CeedVectorGetData(vec, &impl));
213   CeedCallBackend(CeedVectorSetAllInvalid_Cuda(vec));
214   switch (mem_type) {
215     case CEED_MEM_HOST:
216       return CeedVectorSetArrayHost_Cuda(vec, copy_mode, array);
217     case CEED_MEM_DEVICE:
218       return CeedVectorSetArrayDevice_Cuda(vec, copy_mode, array);
219   }
220   return CEED_ERROR_UNSUPPORTED;
221 }
222 
223 //------------------------------------------------------------------------------
// Copy strided entries of a host array into a second host array
225 //------------------------------------------------------------------------------
226 static int CeedHostCopyStrided_Cuda(CeedScalar *h_array, CeedSize start, CeedSize stop, CeedSize step, CeedScalar *h_copy_array) {
227   for (CeedSize i = start; i < stop; i += step) h_copy_array[i] = h_array[i];
228   return CEED_ERROR_SUCCESS;
229 }
230 
231 //------------------------------------------------------------------------------
// Copy strided entries of a device array into a second device array (impl in .cu file)
233 //------------------------------------------------------------------------------
234 int CeedDeviceCopyStrided_Cuda(CeedScalar *d_array, CeedSize start, CeedSize stop, CeedSize step, CeedScalar *d_copy_array);
235 
236 //------------------------------------------------------------------------------
// Copy strided entries of a vector into another vector
238 //------------------------------------------------------------------------------
239 static int CeedVectorCopyStrided_Cuda(CeedVector vec, CeedSize start, CeedSize stop, CeedSize step, CeedVector vec_copy) {
240   CeedSize         length;
241   CeedVector_Cuda *impl;
242 
243   CeedCallBackend(CeedVectorGetData(vec, &impl));
244   {
245     CeedSize length_vec, length_copy;
246 
247     CeedCallBackend(CeedVectorGetLength(vec, &length_vec));
248     CeedCallBackend(CeedVectorGetLength(vec_copy, &length_copy));
249     length = length_vec < length_copy ? length_vec : length_copy;
250   }
251   if (stop == -1) stop = length;
252   // Set value for synced device/host array
253   if (impl->d_array) {
254     CeedScalar *copy_array;
255 
256     CeedCallBackend(CeedVectorGetArray(vec_copy, CEED_MEM_DEVICE, &copy_array));
257 #if (CUDA_VERSION >= 12000)
258     cublasHandle_t handle;
259     Ceed           ceed;
260 
261     CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
262     CeedCallBackend(CeedGetCublasHandle_Cuda(ceed, &handle));
263 #if defined(CEED_SCALAR_IS_FP32)
264     CeedCallCublas(ceed, cublasScopy_64(handle, (int64_t)(stop - start), impl->d_array + start, (int64_t)step, copy_array + start, (int64_t)step));
265 #else  /* CEED_SCALAR */
266     CeedCallCublas(ceed, cublasDcopy_64(handle, (int64_t)(stop - start), impl->d_array + start, (int64_t)step, copy_array + start, (int64_t)step));
267 #endif /* CEED_SCALAR */
268     CeedCallBackend(CeedDestroy(&ceed));
269 #else  /* CUDA_VERSION */
270     CeedCallBackend(CeedDeviceCopyStrided_Cuda(impl->d_array, start, stop, step, copy_array));
271 #endif /* CUDA_VERSION */
272     CeedCallBackend(CeedVectorRestoreArray(vec_copy, &copy_array));
273     impl->h_array = NULL;
274   } else if (impl->h_array) {
275     CeedScalar *copy_array;
276 
277     CeedCallBackend(CeedVectorGetArray(vec_copy, CEED_MEM_HOST, &copy_array));
278     CeedCallBackend(CeedHostCopyStrided_Cuda(impl->h_array, start, stop, step, copy_array));
279     CeedCallBackend(CeedVectorRestoreArray(vec_copy, &copy_array));
280     impl->d_array = NULL;
281   } else {
282     return CeedError(CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND, "CeedVector must have valid data set");
283   }
284   return CEED_ERROR_SUCCESS;
285 }
286 
287 //------------------------------------------------------------------------------
288 // Set host array to value
289 //------------------------------------------------------------------------------
290 static int CeedHostSetValue_Cuda(CeedScalar *h_array, CeedSize length, CeedScalar val) {
291   for (CeedSize i = 0; i < length; i++) h_array[i] = val;
292   return CEED_ERROR_SUCCESS;
293 }
294 
295 //------------------------------------------------------------------------------
296 // Set device array to value (impl in .cu file)
297 //------------------------------------------------------------------------------
298 int CeedDeviceSetValue_Cuda(CeedScalar *d_array, CeedSize length, CeedScalar val);
299 
300 //------------------------------------------------------------------------------
301 // Set a vector to a value
302 //------------------------------------------------------------------------------
303 static int CeedVectorSetValue_Cuda(CeedVector vec, CeedScalar val) {
304   CeedSize         length;
305   CeedVector_Cuda *impl;
306 
307   CeedCallBackend(CeedVectorGetData(vec, &impl));
308   CeedCallBackend(CeedVectorGetLength(vec, &length));
309   // Set value for synced device/host array
310   if (!impl->d_array && !impl->h_array) {
311     if (impl->d_array_borrowed) {
312       impl->d_array = impl->d_array_borrowed;
313     } else if (impl->h_array_borrowed) {
314       impl->h_array = impl->h_array_borrowed;
315     } else if (impl->d_array_owned) {
316       impl->d_array = impl->d_array_owned;
317     } else if (impl->h_array_owned) {
318       impl->h_array = impl->h_array_owned;
319     } else {
320       CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL));
321     }
322   }
323   if (impl->d_array) {
324     if (val == 0) {
325       CeedCallCuda(CeedVectorReturnCeed(vec), cudaMemset(impl->d_array, 0, length * sizeof(CeedScalar)));
326     } else {
327       CeedCallBackend(CeedDeviceSetValue_Cuda(impl->d_array, length, val));
328     }
329     impl->h_array = NULL;
330   } else if (impl->h_array) {
331     CeedCallBackend(CeedHostSetValue_Cuda(impl->h_array, length, val));
332     impl->d_array = NULL;
333   }
334   return CEED_ERROR_SUCCESS;
335 }
336 
337 //------------------------------------------------------------------------------
338 // Set host array to value strided
339 //------------------------------------------------------------------------------
340 static int CeedHostSetValueStrided_Cuda(CeedScalar *h_array, CeedSize start, CeedSize stop, CeedSize step, CeedScalar val) {
341   for (CeedSize i = start; i < stop; i += step) h_array[i] = val;
342   return CEED_ERROR_SUCCESS;
343 }
344 
345 //------------------------------------------------------------------------------
346 // Set device array to value strided (impl in .cu file)
347 //------------------------------------------------------------------------------
348 int CeedDeviceSetValueStrided_Cuda(CeedScalar *d_array, CeedSize start, CeedSize stop, CeedSize step, CeedScalar val);
349 
350 //------------------------------------------------------------------------------
351 // Set a vector to a value strided
352 //------------------------------------------------------------------------------
353 static int CeedVectorSetValueStrided_Cuda(CeedVector vec, CeedSize start, CeedSize stop, CeedSize step, CeedScalar val) {
354   CeedSize         length;
355   CeedVector_Cuda *impl;
356 
357   CeedCallBackend(CeedVectorGetData(vec, &impl));
358   CeedCallBackend(CeedVectorGetLength(vec, &length));
359   // Set value for synced device/host array
360   if (stop == -1) stop = length;
361   if (impl->d_array) {
362     CeedCallBackend(CeedDeviceSetValueStrided_Cuda(impl->d_array, start, stop, step, val));
363     impl->h_array = NULL;
364   } else if (impl->h_array) {
365     CeedCallBackend(CeedHostSetValueStrided_Cuda(impl->h_array, start, stop, step, val));
366     impl->d_array = NULL;
367   } else {
368     return CeedError(CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND, "CeedVector must have valid data set");
369   }
370   return CEED_ERROR_SUCCESS;
371 }
372 
373 //------------------------------------------------------------------------------
374 // Vector Take Array
375 //------------------------------------------------------------------------------
376 static int CeedVectorTakeArray_Cuda(CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
377   CeedVector_Cuda *impl;
378 
379   CeedCallBackend(CeedVectorGetData(vec, &impl));
380   // Sync array to requested mem_type
381   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
382   // Update pointer
383   switch (mem_type) {
384     case CEED_MEM_HOST:
385       (*array)               = impl->h_array_borrowed;
386       impl->h_array_borrowed = NULL;
387       impl->h_array          = NULL;
388       break;
389     case CEED_MEM_DEVICE:
390       (*array)               = impl->d_array_borrowed;
391       impl->d_array_borrowed = NULL;
392       impl->d_array          = NULL;
393       break;
394   }
395   return CEED_ERROR_SUCCESS;
396 }
397 
398 //------------------------------------------------------------------------------
// Core logic for array synchronization for GetArray.
400 //   If a different memory type is most up to date, this will perform a copy
401 //------------------------------------------------------------------------------
402 static int CeedVectorGetArrayCore_Cuda(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
403   CeedVector_Cuda *impl;
404 
405   CeedCallBackend(CeedVectorGetData(vec, &impl));
406   // Sync array to requested mem_type
407   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
408   // Update pointer
409   switch (mem_type) {
410     case CEED_MEM_HOST:
411       *array = impl->h_array;
412       break;
413     case CEED_MEM_DEVICE:
414       *array = impl->d_array;
415       break;
416   }
417   return CEED_ERROR_SUCCESS;
418 }
419 
420 //------------------------------------------------------------------------------
421 // Get read-only access to a vector via the specified mem_type
422 //------------------------------------------------------------------------------
423 static int CeedVectorGetArrayRead_Cuda(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) {
424   return CeedVectorGetArrayCore_Cuda(vec, mem_type, (CeedScalar **)array);
425 }
426 
427 //------------------------------------------------------------------------------
428 // Get read/write access to a vector via the specified mem_type
429 //------------------------------------------------------------------------------
430 static int CeedVectorGetArray_Cuda(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
431   CeedVector_Cuda *impl;
432 
433   CeedCallBackend(CeedVectorGetData(vec, &impl));
434   CeedCallBackend(CeedVectorGetArrayCore_Cuda(vec, mem_type, array));
435   CeedCallBackend(CeedVectorSetAllInvalid_Cuda(vec));
436   switch (mem_type) {
437     case CEED_MEM_HOST:
438       impl->h_array = *array;
439       break;
440     case CEED_MEM_DEVICE:
441       impl->d_array = *array;
442       break;
443   }
444   return CEED_ERROR_SUCCESS;
445 }
446 
447 //------------------------------------------------------------------------------
448 // Get write access to a vector via the specified mem_type
449 //------------------------------------------------------------------------------
450 static int CeedVectorGetArrayWrite_Cuda(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
451   bool             has_array_of_type = true;
452   CeedVector_Cuda *impl;
453 
454   CeedCallBackend(CeedVectorGetData(vec, &impl));
455   CeedCallBackend(CeedVectorHasArrayOfType_Cuda(vec, mem_type, &has_array_of_type));
456   if (!has_array_of_type) {
457     // Allocate if array is not yet allocated
458     CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL));
459   } else {
460     // Select dirty array
461     switch (mem_type) {
462       case CEED_MEM_HOST:
463         if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed;
464         else impl->h_array = impl->h_array_owned;
465         break;
466       case CEED_MEM_DEVICE:
467         if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed;
468         else impl->d_array = impl->d_array_owned;
469     }
470   }
471   return CeedVectorGetArray_Cuda(vec, mem_type, array);
472 }
473 
474 //------------------------------------------------------------------------------
475 // Get the norm of a CeedVector
476 //------------------------------------------------------------------------------
477 static int CeedVectorNorm_Cuda(CeedVector vec, CeedNormType type, CeedScalar *norm) {
478   Ceed     ceed;
479   CeedSize length;
480 #if (CUDA_VERSION < 12000)
481   CeedSize num_calls;
482 #endif /* CUDA_VERSION */
483   const CeedScalar *d_array;
484   CeedVector_Cuda  *impl;
485   cublasHandle_t    handle;
486 
487   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
488   CeedCallBackend(CeedVectorGetData(vec, &impl));
489   CeedCallBackend(CeedVectorGetLength(vec, &length));
490   CeedCallBackend(CeedGetCublasHandle_Cuda(ceed, &handle));
491 
492 #if (CUDA_VERSION < 12000)
493   // With CUDA 12, we can use the 64-bit integer interface. Prior to that,
494   // we need to check if the vector is too long to handle with int32,
495   // and if so, divide it into subsections for repeated cuBLAS calls.
496   num_calls = length / INT_MAX;
497   if (length % INT_MAX > 0) num_calls += 1;
498 #endif /* CUDA_VERSION */
499 
500   // Compute norm
501   CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array));
502   switch (type) {
503     case CEED_NORM_1: {
504       *norm = 0.0;
505 #if defined(CEED_SCALAR_IS_FP32)
506 #if (CUDA_VERSION >= 12000)  // We have CUDA 12, and can use 64-bit integers
507       CeedCallCublas(ceed, cublasSasum_64(handle, (int64_t)length, (float *)d_array, 1, (float *)norm));
508 #else  /* CUDA_VERSION */
509       float  sub_norm = 0.0;
510       float *d_array_start;
511 
512       for (CeedInt i = 0; i < num_calls; i++) {
513         d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
514         CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
515         CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
516 
517         CeedCallCublas(ceed, cublasSasum(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
518         *norm += sub_norm;
519       }
520 #endif /* CUDA_VERSION */
521 #else  /* CEED_SCALAR */
522 #if (CUDA_VERSION >= 12000)
523       CeedCallCublas(ceed, cublasDasum_64(handle, (int64_t)length, (double *)d_array, 1, (double *)norm));
524 #else  /* CUDA_VERSION */
525       double  sub_norm = 0.0;
526       double *d_array_start;
527 
528       for (CeedInt i = 0; i < num_calls; i++) {
529         d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
530         CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
531         CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
532 
533         CeedCallCublas(ceed, cublasDasum(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
534         *norm += sub_norm;
535       }
536 #endif /* CUDA_VERSION */
537 #endif /* CEED_SCALAR */
538       break;
539     }
540     case CEED_NORM_2: {
541 #if defined(CEED_SCALAR_IS_FP32)
542 #if (CUDA_VERSION >= 12000)
543       CeedCallCublas(ceed, cublasSnrm2_64(handle, (int64_t)length, (float *)d_array, 1, (float *)norm));
544 #else  /* CUDA_VERSION */
545       float  sub_norm = 0.0, norm_sum = 0.0;
546       float *d_array_start;
547 
548       for (CeedInt i = 0; i < num_calls; i++) {
549         d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
550         CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
551         CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
552 
553         CeedCallCublas(ceed, cublasSnrm2(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
554         norm_sum += sub_norm * sub_norm;
555       }
556       *norm = sqrt(norm_sum);
557 #endif /* CUDA_VERSION */
558 #else  /* CEED_SCALAR */
559 #if (CUDA_VERSION >= 12000)
560       CeedCallCublas(ceed, cublasDnrm2_64(handle, (int64_t)length, (double *)d_array, 1, (double *)norm));
561 #else  /* CUDA_VERSION */
562       double  sub_norm = 0.0, norm_sum = 0.0;
563       double *d_array_start;
564 
565       for (CeedInt i = 0; i < num_calls; i++) {
566         d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
567         CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
568         CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
569 
570         CeedCallCublas(ceed, cublasDnrm2(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
571         norm_sum += sub_norm * sub_norm;
572       }
573       *norm = sqrt(norm_sum);
574 #endif /* CUDA_VERSION */
575 #endif /* CEED_SCALAR */
576       break;
577     }
578     case CEED_NORM_MAX: {
579 #if defined(CEED_SCALAR_IS_FP32)
580 #if (CUDA_VERSION >= 12000)
581       int64_t    index;
582       CeedScalar norm_no_abs;
583 
584       CeedCallCublas(ceed, cublasIsamax_64(handle, (int64_t)length, (float *)d_array, 1, &index));
585       CeedCallCuda(ceed, cudaMemcpy(&norm_no_abs, impl->d_array + index - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
586       *norm = fabs(norm_no_abs);
587 #else  /* CUDA_VERSION */
588       CeedInt index;
589       float   sub_max = 0.0, current_max = 0.0;
590       float  *d_array_start;
591 
592       for (CeedInt i = 0; i < num_calls; i++) {
593         d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
594         CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
595         CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
596 
597         CeedCallCublas(ceed, cublasIsamax(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &index));
598         CeedCallCuda(ceed, cudaMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
599         if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
600       }
601       *norm = current_max;
602 #endif /* CUDA_VERSION */
603 #else  /* CEED_SCALAR */
604 #if (CUDA_VERSION >= 12000)
605       int64_t    index;
606       CeedScalar norm_no_abs;
607 
608       CeedCallCublas(ceed, cublasIdamax_64(handle, (int64_t)length, (double *)d_array, 1, &index));
609       CeedCallCuda(ceed, cudaMemcpy(&norm_no_abs, impl->d_array + index - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
610       *norm = fabs(norm_no_abs);
611 #else  /* CUDA_VERSION */
612       CeedInt index;
613       double  sub_max = 0.0, current_max = 0.0;
614       double *d_array_start;
615 
616       for (CeedInt i = 0; i < num_calls; i++) {
617         d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
618         CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
619         CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
620 
621         CeedCallCublas(ceed, cublasIdamax(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &index));
622         CeedCallCuda(ceed, cudaMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
623         if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
624       }
625       *norm = current_max;
626 #endif /* CUDA_VERSION */
627 #endif /* CEED_SCALAR */
628       break;
629     }
630   }
631   CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array));
632   CeedCallBackend(CeedDestroy(&ceed));
633   return CEED_ERROR_SUCCESS;
634 }
635 
636 //------------------------------------------------------------------------------
637 // Take reciprocal of a vector on host
638 //------------------------------------------------------------------------------
639 static int CeedHostReciprocal_Cuda(CeedScalar *h_array, CeedSize length) {
640   for (CeedSize i = 0; i < length; i++) {
641     if (fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i];
642   }
643   return CEED_ERROR_SUCCESS;
644 }
645 
646 //------------------------------------------------------------------------------
647 // Take reciprocal of a vector on device (impl in .cu file)
648 //------------------------------------------------------------------------------
649 int CeedDeviceReciprocal_Cuda(CeedScalar *d_array, CeedSize length);
650 
651 //------------------------------------------------------------------------------
652 // Take reciprocal of a vector
653 //------------------------------------------------------------------------------
654 static int CeedVectorReciprocal_Cuda(CeedVector vec) {
655   CeedSize         length;
656   CeedVector_Cuda *impl;
657 
658   CeedCallBackend(CeedVectorGetData(vec, &impl));
659   CeedCallBackend(CeedVectorGetLength(vec, &length));
660   // Set value for synced device/host array
661   if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Cuda(impl->d_array, length));
662   if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Cuda(impl->h_array, length));
663   return CEED_ERROR_SUCCESS;
664 }
665 
666 //------------------------------------------------------------------------------
667 // Compute x = alpha x on the host
668 //------------------------------------------------------------------------------
669 static int CeedHostScale_Cuda(CeedScalar *x_array, CeedScalar alpha, CeedSize length) {
670   for (CeedSize i = 0; i < length; i++) x_array[i] *= alpha;
671   return CEED_ERROR_SUCCESS;
672 }
673 
674 //------------------------------------------------------------------------------
675 // Compute x = alpha x on device (impl in .cu file)
676 //------------------------------------------------------------------------------
677 int CeedDeviceScale_Cuda(CeedScalar *x_array, CeedScalar alpha, CeedSize length);
678 
679 //------------------------------------------------------------------------------
680 // Compute x = alpha x
681 //------------------------------------------------------------------------------
682 static int CeedVectorScale_Cuda(CeedVector x, CeedScalar alpha) {
683   CeedSize         length;
684   CeedVector_Cuda *impl;
685 
686   CeedCallBackend(CeedVectorGetData(x, &impl));
687   CeedCallBackend(CeedVectorGetLength(x, &length));
688   // Set value for synced device/host array
689   if (impl->d_array) {
690 #if (CUDA_VERSION >= 12000)
691     cublasHandle_t handle;
692 
693     CeedCallBackend(CeedGetCublasHandle_Cuda(CeedVectorReturnCeed(x), &handle));
694 #if defined(CEED_SCALAR_IS_FP32)
695     CeedCallCublas(CeedVectorReturnCeed(x), cublasSscal_64(handle, (int64_t)length, &alpha, impl->d_array, 1));
696 #else  /* CEED_SCALAR */
697     CeedCallCublas(CeedVectorReturnCeed(x), cublasDscal_64(handle, (int64_t)length, &alpha, impl->d_array, 1));
698 #endif /* CEED_SCALAR */
699 #else  /* CUDA_VERSION */
700     CeedCallBackend(CeedDeviceScale_Cuda(impl->d_array, alpha, length));
701 #endif /* CUDA_VERSION */
702     impl->h_array = NULL;
703   } else if (impl->h_array) {
704     CeedCallBackend(CeedHostScale_Cuda(impl->h_array, alpha, length));
705     impl->d_array = NULL;
706   }
707   return CEED_ERROR_SUCCESS;
708 }
709 
710 //------------------------------------------------------------------------------
711 // Compute y = alpha x + y on the host
712 //------------------------------------------------------------------------------
713 static int CeedHostAXPY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) {
714   for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i];
715   return CEED_ERROR_SUCCESS;
716 }
717 
718 //------------------------------------------------------------------------------
719 // Compute y = alpha x + y on device (impl in .cu file)
720 //------------------------------------------------------------------------------
721 int CeedDeviceAXPY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length);
722 
723 //------------------------------------------------------------------------------
724 // Compute y = alpha x + y
725 //------------------------------------------------------------------------------
726 static int CeedVectorAXPY_Cuda(CeedVector y, CeedScalar alpha, CeedVector x) {
727   CeedSize         length;
728   CeedVector_Cuda *y_impl, *x_impl;
729 
730   CeedCallBackend(CeedVectorGetData(y, &y_impl));
731   CeedCallBackend(CeedVectorGetData(x, &x_impl));
732   CeedCallBackend(CeedVectorGetLength(y, &length));
733   // Set value for synced device/host array
734   if (y_impl->d_array) {
735     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
736 #if (CUDA_VERSION >= 12000)
737     cublasHandle_t handle;
738 
739     CeedCallBackend(CeedGetCublasHandle_Cuda(CeedVectorReturnCeed(y), &handle));
740 #if defined(CEED_SCALAR_IS_FP32)
741     CeedCallCublas(CeedVectorReturnCeed(y), cublasSaxpy_64(handle, (int64_t)length, &alpha, x_impl->d_array, 1, y_impl->d_array, 1));
742 #else  /* CEED_SCALAR */
743     CeedCallCublas(CeedVectorReturnCeed(y), cublasDaxpy_64(handle, (int64_t)length, &alpha, x_impl->d_array, 1, y_impl->d_array, 1));
744 #endif /* CEED_SCALAR */
745 #else  /* CUDA_VERSION */
746     CeedCallBackend(CeedDeviceAXPY_Cuda(y_impl->d_array, alpha, x_impl->d_array, length));
747 #endif /* CUDA_VERSION */
748     y_impl->h_array = NULL;
749   } else if (y_impl->h_array) {
750     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
751     CeedCallBackend(CeedHostAXPY_Cuda(y_impl->h_array, alpha, x_impl->h_array, length));
752     y_impl->d_array = NULL;
753   }
754   return CEED_ERROR_SUCCESS;
755 }
756 
757 //------------------------------------------------------------------------------
758 // Compute y = alpha x + beta y on the host
759 //------------------------------------------------------------------------------
760 static int CeedHostAXPBY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length) {
761   for (CeedSize i = 0; i < length; i++) y_array[i] = alpha * x_array[i] + beta * y_array[i];
762   return CEED_ERROR_SUCCESS;
763 }
764 
765 //------------------------------------------------------------------------------
766 // Compute y = alpha x + beta y on device (impl in .cu file)
767 //------------------------------------------------------------------------------
768 int CeedDeviceAXPBY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length);
769 
770 //------------------------------------------------------------------------------
771 // Compute y = alpha x + beta y
772 //------------------------------------------------------------------------------
773 static int CeedVectorAXPBY_Cuda(CeedVector y, CeedScalar alpha, CeedScalar beta, CeedVector x) {
774   CeedSize         length;
775   CeedVector_Cuda *y_impl, *x_impl;
776 
777   CeedCallBackend(CeedVectorGetData(y, &y_impl));
778   CeedCallBackend(CeedVectorGetData(x, &x_impl));
779   CeedCallBackend(CeedVectorGetLength(y, &length));
780   // Set value for synced device/host array
781   if (y_impl->d_array) {
782     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
783     CeedCallBackend(CeedDeviceAXPBY_Cuda(y_impl->d_array, alpha, beta, x_impl->d_array, length));
784   }
785   if (y_impl->h_array) {
786     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
787     CeedCallBackend(CeedHostAXPBY_Cuda(y_impl->h_array, alpha, beta, x_impl->h_array, length));
788   }
789   return CEED_ERROR_SUCCESS;
790 }
791 
792 //------------------------------------------------------------------------------
793 // Compute the pointwise multiplication w = x .* y on the host
794 //------------------------------------------------------------------------------
795 static int CeedHostPointwiseMult_Cuda(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) {
796   for (CeedSize i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i];
797   return CEED_ERROR_SUCCESS;
798 }
799 
800 //------------------------------------------------------------------------------
801 // Compute the pointwise multiplication w = x .* y on device (impl in .cu file)
802 //------------------------------------------------------------------------------
803 int CeedDevicePointwiseMult_Cuda(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length);
804 
805 //------------------------------------------------------------------------------
806 // Compute the pointwise multiplication w = x .* y
807 //------------------------------------------------------------------------------
808 static int CeedVectorPointwiseMult_Cuda(CeedVector w, CeedVector x, CeedVector y) {
809   CeedSize         length;
810   CeedVector_Cuda *w_impl, *x_impl, *y_impl;
811 
812   CeedCallBackend(CeedVectorGetData(w, &w_impl));
813   CeedCallBackend(CeedVectorGetData(x, &x_impl));
814   CeedCallBackend(CeedVectorGetData(y, &y_impl));
815   CeedCallBackend(CeedVectorGetLength(w, &length));
816   // Set value for synced device/host array
817   if (!w_impl->d_array && !w_impl->h_array) {
818     CeedCallBackend(CeedVectorSetValue(w, 0.0));
819   }
820   if (w_impl->d_array) {
821     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
822     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE));
823     CeedCallBackend(CeedDevicePointwiseMult_Cuda(w_impl->d_array, x_impl->d_array, y_impl->d_array, length));
824   }
825   if (w_impl->h_array) {
826     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
827     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST));
828     CeedCallBackend(CeedHostPointwiseMult_Cuda(w_impl->h_array, x_impl->h_array, y_impl->h_array, length));
829   }
830   return CEED_ERROR_SUCCESS;
831 }
832 
833 //------------------------------------------------------------------------------
834 // Destroy the vector
835 //------------------------------------------------------------------------------
836 static int CeedVectorDestroy_Cuda(const CeedVector vec) {
837   CeedVector_Cuda *impl;
838 
839   CeedCallBackend(CeedVectorGetData(vec, &impl));
840   CeedCallCuda(CeedVectorReturnCeed(vec), cudaFree(impl->d_array_owned));
841   CeedCallBackend(CeedFree(&impl->h_array_owned));
842   CeedCallBackend(CeedFree(&impl));
843   return CEED_ERROR_SUCCESS;
844 }
845 
846 //------------------------------------------------------------------------------
847 // Create a vector of the specified length (does not allocate memory)
848 //------------------------------------------------------------------------------
849 int CeedVectorCreate_Cuda(CeedSize n, CeedVector vec) {
850   CeedVector_Cuda *impl;
851   Ceed             ceed;
852 
853   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
854   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Cuda));
855   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Cuda));
856   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Cuda));
857   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Cuda));
858   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "CopyStrided", CeedVectorCopyStrided_Cuda));
859   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", CeedVectorSetValue_Cuda));
860   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValueStrided", CeedVectorSetValueStrided_Cuda));
861   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Cuda));
862   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Cuda));
863   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Cuda));
864   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Cuda));
865   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Norm", CeedVectorNorm_Cuda));
866   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Cuda));
867   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Scale", CeedVectorScale_Cuda));
868   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPY", CeedVectorAXPY_Cuda));
869   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPBY", CeedVectorAXPBY_Cuda));
870   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Cuda));
871   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Cuda));
872   CeedCallBackend(CeedDestroy(&ceed));
873   CeedCallBackend(CeedCalloc(1, &impl));
874   CeedCallBackend(CeedVectorSetData(vec, impl));
875   return CEED_ERROR_SUCCESS;
876 }
877 
878 //------------------------------------------------------------------------------
879