xref: /libCEED/backends/cuda-ref/ceed-cuda-ref-vector.c (revision 7113573b6efd54558bb98b919dff5d6d8ffcff54)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <cublas_v2.h>
11 #include <cuda_runtime.h>
12 #include <math.h>
13 #include <stdbool.h>
14 #include <string.h>
15 
16 #include "../cuda/ceed-cuda-common.h"
17 #include "ceed-cuda-ref.h"
18 
19 //------------------------------------------------------------------------------
20 // Check if host/device sync is needed
21 //------------------------------------------------------------------------------
22 static inline int CeedVectorNeedSync_Cuda(const CeedVector vec, CeedMemType mem_type, bool *need_sync) {
23   CeedVector_Cuda *impl;
24   CeedCallBackend(CeedVectorGetData(vec, &impl));
25 
26   bool has_valid_array = false;
27   CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array));
28   switch (mem_type) {
29     case CEED_MEM_HOST:
30       *need_sync = has_valid_array && !impl->h_array;
31       break;
32     case CEED_MEM_DEVICE:
33       *need_sync = has_valid_array && !impl->d_array;
34       break;
35   }
36 
37   return CEED_ERROR_SUCCESS;
38 }
39 
40 //------------------------------------------------------------------------------
41 // Sync host to device
42 //------------------------------------------------------------------------------
43 static inline int CeedVectorSyncH2D_Cuda(const CeedVector vec) {
44   Ceed ceed;
45   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
46   CeedVector_Cuda *impl;
47   CeedCallBackend(CeedVectorGetData(vec, &impl));
48 
49   CeedCheck(impl->h_array, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
50 
51   CeedSize length;
52   CeedCallBackend(CeedVectorGetLength(vec, &length));
53   size_t bytes = length * sizeof(CeedScalar);
54 
55   if (impl->d_array_borrowed) {
56     impl->d_array = impl->d_array_borrowed;
57   } else if (impl->d_array_owned) {
58     impl->d_array = impl->d_array_owned;
59   } else {
60     CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_array_owned, bytes));
61     impl->d_array = impl->d_array_owned;
62   }
63 
64   CeedCallCuda(ceed, cudaMemcpy(impl->d_array, impl->h_array, bytes, cudaMemcpyHostToDevice));
65 
66   return CEED_ERROR_SUCCESS;
67 }
68 
69 //------------------------------------------------------------------------------
70 // Sync device to host
71 //------------------------------------------------------------------------------
72 static inline int CeedVectorSyncD2H_Cuda(const CeedVector vec) {
73   Ceed ceed;
74   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
75   CeedVector_Cuda *impl;
76   CeedCallBackend(CeedVectorGetData(vec, &impl));
77 
78   CeedCheck(impl->d_array, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
79 
80   if (impl->h_array_borrowed) {
81     impl->h_array = impl->h_array_borrowed;
82   } else if (impl->h_array_owned) {
83     impl->h_array = impl->h_array_owned;
84   } else {
85     CeedSize length;
86     CeedCallBackend(CeedVectorGetLength(vec, &length));
87     CeedCallBackend(CeedCalloc(length, &impl->h_array_owned));
88     impl->h_array = impl->h_array_owned;
89   }
90 
91   CeedSize length;
92   CeedCallBackend(CeedVectorGetLength(vec, &length));
93   size_t bytes = length * sizeof(CeedScalar);
94   CeedCallCuda(ceed, cudaMemcpy(impl->h_array, impl->d_array, bytes, cudaMemcpyDeviceToHost));
95 
96   return CEED_ERROR_SUCCESS;
97 }
98 
99 //------------------------------------------------------------------------------
100 // Sync arrays
101 //------------------------------------------------------------------------------
102 static int CeedVectorSyncArray_Cuda(const CeedVector vec, CeedMemType mem_type) {
103   // Check whether device/host sync is needed
104   bool need_sync = false;
105   CeedCallBackend(CeedVectorNeedSync_Cuda(vec, mem_type, &need_sync));
106   if (!need_sync) return CEED_ERROR_SUCCESS;
107 
108   switch (mem_type) {
109     case CEED_MEM_HOST:
110       return CeedVectorSyncD2H_Cuda(vec);
111     case CEED_MEM_DEVICE:
112       return CeedVectorSyncH2D_Cuda(vec);
113   }
114   return CEED_ERROR_UNSUPPORTED;
115 }
116 
117 //------------------------------------------------------------------------------
118 // Set all pointers as invalid
119 //------------------------------------------------------------------------------
120 static inline int CeedVectorSetAllInvalid_Cuda(const CeedVector vec) {
121   CeedVector_Cuda *impl;
122   CeedCallBackend(CeedVectorGetData(vec, &impl));
123 
124   impl->h_array = NULL;
125   impl->d_array = NULL;
126 
127   return CEED_ERROR_SUCCESS;
128 }
129 
130 //------------------------------------------------------------------------------
131 // Check if CeedVector has any valid pointer
132 //------------------------------------------------------------------------------
133 static inline int CeedVectorHasValidArray_Cuda(const CeedVector vec, bool *has_valid_array) {
134   CeedVector_Cuda *impl;
135   CeedCallBackend(CeedVectorGetData(vec, &impl));
136 
137   *has_valid_array = !!impl->h_array || !!impl->d_array;
138 
139   return CEED_ERROR_SUCCESS;
140 }
141 
142 //------------------------------------------------------------------------------
143 // Check if has array of given type
144 //------------------------------------------------------------------------------
145 static inline int CeedVectorHasArrayOfType_Cuda(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) {
146   CeedVector_Cuda *impl;
147   CeedCallBackend(CeedVectorGetData(vec, &impl));
148 
149   switch (mem_type) {
150     case CEED_MEM_HOST:
151       *has_array_of_type = !!impl->h_array_borrowed || !!impl->h_array_owned;
152       break;
153     case CEED_MEM_DEVICE:
154       *has_array_of_type = !!impl->d_array_borrowed || !!impl->d_array_owned;
155       break;
156   }
157 
158   return CEED_ERROR_SUCCESS;
159 }
160 
161 //------------------------------------------------------------------------------
162 // Check if has borrowed array of given type
163 //------------------------------------------------------------------------------
164 static inline int CeedVectorHasBorrowedArrayOfType_Cuda(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) {
165   CeedVector_Cuda *impl;
166   CeedCallBackend(CeedVectorGetData(vec, &impl));
167 
168   switch (mem_type) {
169     case CEED_MEM_HOST:
170       *has_borrowed_array_of_type = !!impl->h_array_borrowed;
171       break;
172     case CEED_MEM_DEVICE:
173       *has_borrowed_array_of_type = !!impl->d_array_borrowed;
174       break;
175   }
176 
177   return CEED_ERROR_SUCCESS;
178 }
179 
180 //------------------------------------------------------------------------------
181 // Set array from host
182 //------------------------------------------------------------------------------
183 static int CeedVectorSetArrayHost_Cuda(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
184   CeedVector_Cuda *impl;
185   CeedCallBackend(CeedVectorGetData(vec, &impl));
186 
187   switch (copy_mode) {
188     case CEED_COPY_VALUES: {
189       CeedSize length;
190       if (!impl->h_array_owned) {
191         CeedCallBackend(CeedVectorGetLength(vec, &length));
192         CeedCallBackend(CeedMalloc(length, &impl->h_array_owned));
193       }
194       impl->h_array_borrowed = NULL;
195       impl->h_array          = impl->h_array_owned;
196       if (array) {
197         CeedSize length;
198         CeedCallBackend(CeedVectorGetLength(vec, &length));
199         size_t bytes = length * sizeof(CeedScalar);
200         memcpy(impl->h_array, array, bytes);
201       }
202     } break;
203     case CEED_OWN_POINTER:
204       CeedCallBackend(CeedFree(&impl->h_array_owned));
205       impl->h_array_owned    = array;
206       impl->h_array_borrowed = NULL;
207       impl->h_array          = array;
208       break;
209     case CEED_USE_POINTER:
210       CeedCallBackend(CeedFree(&impl->h_array_owned));
211       impl->h_array_borrowed = array;
212       impl->h_array          = array;
213       break;
214   }
215 
216   return CEED_ERROR_SUCCESS;
217 }
218 
219 //------------------------------------------------------------------------------
220 // Set array from device
221 //------------------------------------------------------------------------------
222 static int CeedVectorSetArrayDevice_Cuda(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
223   Ceed ceed;
224   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
225   CeedVector_Cuda *impl;
226   CeedCallBackend(CeedVectorGetData(vec, &impl));
227 
228   switch (copy_mode) {
229     case CEED_COPY_VALUES: {
230       CeedSize length;
231       CeedCallBackend(CeedVectorGetLength(vec, &length));
232       size_t bytes = length * sizeof(CeedScalar);
233       if (!impl->d_array_owned) {
234         CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_array_owned, bytes));
235         impl->d_array = impl->d_array_owned;
236       }
237       if (array) CeedCallCuda(ceed, cudaMemcpy(impl->d_array, array, bytes, cudaMemcpyDeviceToDevice));
238     } break;
239     case CEED_OWN_POINTER:
240       CeedCallCuda(ceed, cudaFree(impl->d_array_owned));
241       impl->d_array_owned    = array;
242       impl->d_array_borrowed = NULL;
243       impl->d_array          = array;
244       break;
245     case CEED_USE_POINTER:
246       CeedCallCuda(ceed, cudaFree(impl->d_array_owned));
247       impl->d_array_owned    = NULL;
248       impl->d_array_borrowed = array;
249       impl->d_array          = array;
250       break;
251   }
252 
253   return CEED_ERROR_SUCCESS;
254 }
255 
256 //------------------------------------------------------------------------------
257 // Set the array used by a vector,
258 //   freeing any previously allocated array if applicable
259 //------------------------------------------------------------------------------
260 static int CeedVectorSetArray_Cuda(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) {
261   Ceed ceed;
262   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
263   CeedVector_Cuda *impl;
264   CeedCallBackend(CeedVectorGetData(vec, &impl));
265 
266   CeedCallBackend(CeedVectorSetAllInvalid_Cuda(vec));
267   switch (mem_type) {
268     case CEED_MEM_HOST:
269       return CeedVectorSetArrayHost_Cuda(vec, copy_mode, array);
270     case CEED_MEM_DEVICE:
271       return CeedVectorSetArrayDevice_Cuda(vec, copy_mode, array);
272   }
273 
274   return CEED_ERROR_UNSUPPORTED;
275 }
276 
277 //------------------------------------------------------------------------------
278 // Set host array to value
279 //------------------------------------------------------------------------------
280 static int CeedHostSetValue_Cuda(CeedScalar *h_array, CeedSize length, CeedScalar val) {
281   for (CeedSize i = 0; i < length; i++) h_array[i] = val;
282   return CEED_ERROR_SUCCESS;
283 }
284 
285 //------------------------------------------------------------------------------
286 // Set device array to value (impl in .cu file)
287 //------------------------------------------------------------------------------
288 int CeedDeviceSetValue_Cuda(CeedScalar *d_array, CeedSize length, CeedScalar val);
289 
290 //------------------------------------------------------------------------------
291 // Set a vector to a value
292 //------------------------------------------------------------------------------
293 static int CeedVectorSetValue_Cuda(CeedVector vec, CeedScalar val) {
294   Ceed ceed;
295   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
296   CeedVector_Cuda *impl;
297   CeedCallBackend(CeedVectorGetData(vec, &impl));
298   CeedSize length;
299   CeedCallBackend(CeedVectorGetLength(vec, &length));
300 
301   // Set value for synced device/host array
302   if (!impl->d_array && !impl->h_array) {
303     if (impl->d_array_borrowed) {
304       impl->d_array = impl->d_array_borrowed;
305     } else if (impl->h_array_borrowed) {
306       impl->h_array = impl->h_array_borrowed;
307     } else if (impl->d_array_owned) {
308       impl->d_array = impl->d_array_owned;
309     } else if (impl->h_array_owned) {
310       impl->h_array = impl->h_array_owned;
311     } else {
312       CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL));
313     }
314   }
315   if (impl->d_array) {
316     CeedCallBackend(CeedDeviceSetValue_Cuda(impl->d_array, length, val));
317     impl->h_array = NULL;
318   }
319   if (impl->h_array) {
320     CeedCallBackend(CeedHostSetValue_Cuda(impl->h_array, length, val));
321     impl->d_array = NULL;
322   }
323 
324   return CEED_ERROR_SUCCESS;
325 }
326 
327 //------------------------------------------------------------------------------
328 // Vector Take Array
329 //------------------------------------------------------------------------------
330 static int CeedVectorTakeArray_Cuda(CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
331   Ceed ceed;
332   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
333   CeedVector_Cuda *impl;
334   CeedCallBackend(CeedVectorGetData(vec, &impl));
335 
336   // Sync array to requested mem_type
337   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
338 
339   // Update pointer
340   switch (mem_type) {
341     case CEED_MEM_HOST:
342       (*array)               = impl->h_array_borrowed;
343       impl->h_array_borrowed = NULL;
344       impl->h_array          = NULL;
345       break;
346     case CEED_MEM_DEVICE:
347       (*array)               = impl->d_array_borrowed;
348       impl->d_array_borrowed = NULL;
349       impl->d_array          = NULL;
350       break;
351   }
352 
353   return CEED_ERROR_SUCCESS;
354 }
355 
356 //------------------------------------------------------------------------------
357 // Core logic for array synchronization for GetArray.
358 //   If a different memory type is most up to date, this will perform a copy
359 //------------------------------------------------------------------------------
360 static int CeedVectorGetArrayCore_Cuda(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
361   Ceed ceed;
362   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
363   CeedVector_Cuda *impl;
364   CeedCallBackend(CeedVectorGetData(vec, &impl));
365 
366   // Sync array to requested mem_type
367   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
368 
369   // Update pointer
370   switch (mem_type) {
371     case CEED_MEM_HOST:
372       *array = impl->h_array;
373       break;
374     case CEED_MEM_DEVICE:
375       *array = impl->d_array;
376       break;
377   }
378 
379   return CEED_ERROR_SUCCESS;
380 }
381 
382 //------------------------------------------------------------------------------
383 // Get read-only access to a vector via the specified mem_type
384 //------------------------------------------------------------------------------
385 static int CeedVectorGetArrayRead_Cuda(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) {
386   return CeedVectorGetArrayCore_Cuda(vec, mem_type, (CeedScalar **)array);
387 }
388 
389 //------------------------------------------------------------------------------
390 // Get read/write access to a vector via the specified mem_type
391 //------------------------------------------------------------------------------
392 static int CeedVectorGetArray_Cuda(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
393   CeedVector_Cuda *impl;
394   CeedCallBackend(CeedVectorGetData(vec, &impl));
395 
396   CeedCallBackend(CeedVectorGetArrayCore_Cuda(vec, mem_type, array));
397 
398   CeedCallBackend(CeedVectorSetAllInvalid_Cuda(vec));
399   switch (mem_type) {
400     case CEED_MEM_HOST:
401       impl->h_array = *array;
402       break;
403     case CEED_MEM_DEVICE:
404       impl->d_array = *array;
405       break;
406   }
407 
408   return CEED_ERROR_SUCCESS;
409 }
410 
411 //------------------------------------------------------------------------------
412 // Get write access to a vector via the specified mem_type
413 //------------------------------------------------------------------------------
414 static int CeedVectorGetArrayWrite_Cuda(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
415   CeedVector_Cuda *impl;
416   CeedCallBackend(CeedVectorGetData(vec, &impl));
417 
418   bool has_array_of_type = true;
419   CeedCallBackend(CeedVectorHasArrayOfType_Cuda(vec, mem_type, &has_array_of_type));
420   if (!has_array_of_type) {
421     // Allocate if array is not yet allocated
422     CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL));
423   } else {
424     // Select dirty array
425     switch (mem_type) {
426       case CEED_MEM_HOST:
427         if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed;
428         else impl->h_array = impl->h_array_owned;
429         break;
430       case CEED_MEM_DEVICE:
431         if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed;
432         else impl->d_array = impl->d_array_owned;
433     }
434   }
435 
436   return CeedVectorGetArray_Cuda(vec, mem_type, array);
437 }
438 
439 //------------------------------------------------------------------------------
440 // Get the norm of a CeedVector
441 //------------------------------------------------------------------------------
442 static int CeedVectorNorm_Cuda(CeedVector vec, CeedNormType type, CeedScalar *norm) {
443   Ceed ceed;
444   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
445   CeedVector_Cuda *impl;
446   CeedCallBackend(CeedVectorGetData(vec, &impl));
447   CeedSize length;
448   CeedCallBackend(CeedVectorGetLength(vec, &length));
449   cublasHandle_t handle;
450   CeedCallBackend(CeedGetCublasHandle_Cuda(ceed, &handle));
451 
452 #if CUDA_VERSION < 12000
453   // With CUDA 12, we can use the 64-bit integer interface. Prior to that,
454   // we need to check if the vector is too long to handle with int32,
455   // and if so, divide it into subsections for repeated cuBLAS calls
456   CeedSize num_calls = length / INT_MAX;
457   if (length % INT_MAX > 0) num_calls += 1;
458 #endif
459 
460   // Compute norm
461   const CeedScalar *d_array;
462   CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array));
463   switch (type) {
464     case CEED_NORM_1: {
465       *norm = 0.0;
466       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
467 #if CUDA_VERSION >= 12000  // We have CUDA 12, and can use 64-bit integers
468         CeedCallCublas(ceed, cublasSasum_64(handle, (int64_t)length, (float *)d_array, 1, (float *)norm));
469 #else
470         float  sub_norm = 0.0;
471         float *d_array_start;
472         for (CeedInt i = 0; i < num_calls; i++) {
473           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
474           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
475           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
476           CeedCallCublas(ceed, cublasSasum(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
477           *norm += sub_norm;
478         }
479 #endif
480       } else {
481 #if CUDA_VERSION >= 12000
482         CeedCallCublas(ceed, cublasDasum_64(handle, (int64_t)length, (double *)d_array, 1, (double *)norm));
483 #else
484         double  sub_norm = 0.0;
485         double *d_array_start;
486         for (CeedInt i = 0; i < num_calls; i++) {
487           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
488           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
489           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
490           CeedCallCublas(ceed, cublasDasum(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
491           *norm += sub_norm;
492         }
493 #endif
494       }
495       break;
496     }
497     case CEED_NORM_2: {
498       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
499 #if CUDA_VERSION >= 12000
500         CeedCallCublas(ceed, cublasSnrm2_64(handle, (int64_t)length, (float *)d_array, 1, (float *)norm));
501 #else
502         float  sub_norm = 0.0, norm_sum = 0.0;
503         float *d_array_start;
504         for (CeedInt i = 0; i < num_calls; i++) {
505           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
506           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
507           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
508           CeedCallCublas(ceed, cublasSnrm2(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
509           norm_sum += sub_norm * sub_norm;
510         }
511         *norm            = sqrt(norm_sum);
512 #endif
513       } else {
514 #if CUDA_VERSION >= 12000
515         CeedCallCublas(ceed, cublasDnrm2_64(handle, (int64_t)length, (double *)d_array, 1, (double *)norm));
516 #else
517         double  sub_norm = 0.0, norm_sum = 0.0;
518         double *d_array_start;
519         for (CeedInt i = 0; i < num_calls; i++) {
520           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
521           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
522           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
523           CeedCallCublas(ceed, cublasDnrm2(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
524           norm_sum += sub_norm * sub_norm;
525         }
526         *norm = sqrt(norm_sum);
527 #endif
528       }
529       break;
530     }
531     case CEED_NORM_MAX: {
532       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
533 #if CUDA_VERSION >= 12000
534         int64_t indx;
535         CeedCallCublas(ceed, cublasIsamax_64(handle, (int64_t)length, (float *)d_array, 1, &indx));
536         CeedScalar normNoAbs;
537         CeedCallCuda(ceed, cudaMemcpy(&normNoAbs, impl->d_array + indx - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
538         *norm = fabs(normNoAbs);
539 #else
540         CeedInt indx;
541         float   sub_max = 0.0, current_max = 0.0;
542         float  *d_array_start;
543         for (CeedInt i = 0; i < num_calls; i++) {
544           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
545           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
546           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
547           CeedCallCublas(ceed, cublasIsamax(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &indx));
548           CeedCallCuda(ceed, cudaMemcpy(&sub_max, d_array_start + indx - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
549           if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
550         }
551         *norm = current_max;
552 #endif
553       } else {
554 #if CUDA_VERSION >= 12000
555         int64_t indx;
556         CeedCallCublas(ceed, cublasIdamax_64(handle, (int64_t)length, (double *)d_array, 1, &indx));
557         CeedScalar normNoAbs;
558         CeedCallCuda(ceed, cudaMemcpy(&normNoAbs, impl->d_array + indx - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
559         *norm = fabs(normNoAbs);
560 #else
561         CeedInt indx;
562         double  sub_max = 0.0, current_max = 0.0;
563         double *d_array_start;
564         for (CeedInt i = 0; i < num_calls; i++) {
565           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
566           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
567           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
568           CeedCallCublas(ceed, cublasIdamax(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &indx));
569           CeedCallCuda(ceed, cudaMemcpy(&sub_max, d_array_start + indx - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
570           if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
571         }
572         *norm = current_max;
573 #endif
574       }
575       break;
576     }
577   }
578   CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array));
579 
580   return CEED_ERROR_SUCCESS;
581 }
582 
583 //------------------------------------------------------------------------------
584 // Take reciprocal of a vector on host
585 //------------------------------------------------------------------------------
586 static int CeedHostReciprocal_Cuda(CeedScalar *h_array, CeedSize length) {
587   for (CeedSize i = 0; i < length; i++) {
588     if (fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i];
589   }
590   return CEED_ERROR_SUCCESS;
591 }
592 
593 //------------------------------------------------------------------------------
594 // Take reciprocal of a vector on device (impl in .cu file)
595 //------------------------------------------------------------------------------
596 int CeedDeviceReciprocal_Cuda(CeedScalar *d_array, CeedSize length);
597 
598 //------------------------------------------------------------------------------
599 // Take reciprocal of a vector
600 //------------------------------------------------------------------------------
601 static int CeedVectorReciprocal_Cuda(CeedVector vec) {
602   Ceed ceed;
603   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
604   CeedVector_Cuda *impl;
605   CeedCallBackend(CeedVectorGetData(vec, &impl));
606   CeedSize length;
607   CeedCallBackend(CeedVectorGetLength(vec, &length));
608 
609   // Set value for synced device/host array
610   if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Cuda(impl->d_array, length));
611   if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Cuda(impl->h_array, length));
612 
613   return CEED_ERROR_SUCCESS;
614 }
615 
616 //------------------------------------------------------------------------------
617 // Compute x = alpha x on the host
618 //------------------------------------------------------------------------------
619 static int CeedHostScale_Cuda(CeedScalar *x_array, CeedScalar alpha, CeedSize length) {
620   for (CeedSize i = 0; i < length; i++) x_array[i] *= alpha;
621   return CEED_ERROR_SUCCESS;
622 }
623 
624 //------------------------------------------------------------------------------
625 // Compute x = alpha x on device (impl in .cu file)
626 //------------------------------------------------------------------------------
627 int CeedDeviceScale_Cuda(CeedScalar *x_array, CeedScalar alpha, CeedSize length);
628 
629 //------------------------------------------------------------------------------
630 // Compute x = alpha x
631 //------------------------------------------------------------------------------
632 static int CeedVectorScale_Cuda(CeedVector x, CeedScalar alpha) {
633   Ceed ceed;
634   CeedCallBackend(CeedVectorGetCeed(x, &ceed));
635   CeedVector_Cuda *x_impl;
636   CeedCallBackend(CeedVectorGetData(x, &x_impl));
637   CeedSize length;
638   CeedCallBackend(CeedVectorGetLength(x, &length));
639 
640   // Set value for synced device/host array
641   if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Cuda(x_impl->d_array, alpha, length));
642   if (x_impl->h_array) CeedCallBackend(CeedHostScale_Cuda(x_impl->h_array, alpha, length));
643 
644   return CEED_ERROR_SUCCESS;
645 }
646 
647 //------------------------------------------------------------------------------
648 // Compute y = alpha x + y on the host
649 //------------------------------------------------------------------------------
650 static int CeedHostAXPY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) {
651   for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i];
652   return CEED_ERROR_SUCCESS;
653 }
654 
655 //------------------------------------------------------------------------------
656 // Compute y = alpha x + y on device (impl in .cu file)
657 //------------------------------------------------------------------------------
658 int CeedDeviceAXPY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length);
659 
660 //------------------------------------------------------------------------------
661 // Compute y = alpha x + y
662 //------------------------------------------------------------------------------
663 static int CeedVectorAXPY_Cuda(CeedVector y, CeedScalar alpha, CeedVector x) {
664   Ceed ceed;
665   CeedCallBackend(CeedVectorGetCeed(y, &ceed));
666   CeedVector_Cuda *y_impl, *x_impl;
667   CeedCallBackend(CeedVectorGetData(y, &y_impl));
668   CeedCallBackend(CeedVectorGetData(x, &x_impl));
669   CeedSize length;
670   CeedCallBackend(CeedVectorGetLength(y, &length));
671 
672   // Set value for synced device/host array
673   if (y_impl->d_array) {
674     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
675     CeedCallBackend(CeedDeviceAXPY_Cuda(y_impl->d_array, alpha, x_impl->d_array, length));
676   }
677   if (y_impl->h_array) {
678     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
679     CeedCallBackend(CeedHostAXPY_Cuda(y_impl->h_array, alpha, x_impl->h_array, length));
680   }
681 
682   return CEED_ERROR_SUCCESS;
683 }
684 
685 //------------------------------------------------------------------------------
686 // Compute y = alpha x + beta y on the host
687 //------------------------------------------------------------------------------
688 static int CeedHostAXPBY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length) {
689   for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i] + beta * y_array[i];
690   return CEED_ERROR_SUCCESS;
691 }
692 
693 //------------------------------------------------------------------------------
694 // Compute y = alpha x + beta y on device (impl in .cu file)
695 //------------------------------------------------------------------------------
696 int CeedDeviceAXPBY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length);
697 
698 //------------------------------------------------------------------------------
699 // Compute y = alpha x + beta y
700 //------------------------------------------------------------------------------
701 static int CeedVectorAXPBY_Cuda(CeedVector y, CeedScalar alpha, CeedScalar beta, CeedVector x) {
702   Ceed ceed;
703   CeedCallBackend(CeedVectorGetCeed(y, &ceed));
704   CeedVector_Cuda *y_impl, *x_impl;
705   CeedCallBackend(CeedVectorGetData(y, &y_impl));
706   CeedCallBackend(CeedVectorGetData(x, &x_impl));
707   CeedSize length;
708   CeedCallBackend(CeedVectorGetLength(y, &length));
709 
710   // Set value for synced device/host array
711   if (y_impl->d_array) {
712     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
713     CeedCallBackend(CeedDeviceAXPBY_Cuda(y_impl->d_array, alpha, beta, x_impl->d_array, length));
714   }
715   if (y_impl->h_array) {
716     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
717     CeedCallBackend(CeedHostAXPBY_Cuda(y_impl->h_array, alpha, beta, x_impl->h_array, length));
718   }
719 
720   return CEED_ERROR_SUCCESS;
721 }
722 
723 //------------------------------------------------------------------------------
724 // Compute the pointwise multiplication w = x .* y on the host
725 //------------------------------------------------------------------------------
726 static int CeedHostPointwiseMult_Cuda(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) {
727   for (CeedSize i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i];
728   return CEED_ERROR_SUCCESS;
729 }
730 
731 //------------------------------------------------------------------------------
732 // Compute the pointwise multiplication w = x .* y on device (impl in .cu file)
733 //------------------------------------------------------------------------------
734 int CeedDevicePointwiseMult_Cuda(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length);
735 
736 //------------------------------------------------------------------------------
737 // Compute the pointwise multiplication w = x .* y
738 //------------------------------------------------------------------------------
739 static int CeedVectorPointwiseMult_Cuda(CeedVector w, CeedVector x, CeedVector y) {
740   Ceed ceed;
741   CeedCallBackend(CeedVectorGetCeed(w, &ceed));
742   CeedVector_Cuda *w_impl, *x_impl, *y_impl;
743   CeedCallBackend(CeedVectorGetData(w, &w_impl));
744   CeedCallBackend(CeedVectorGetData(x, &x_impl));
745   CeedCallBackend(CeedVectorGetData(y, &y_impl));
746   CeedSize length;
747   CeedCallBackend(CeedVectorGetLength(w, &length));
748 
749   // Set value for synced device/host array
750   if (!w_impl->d_array && !w_impl->h_array) {
751     CeedCallBackend(CeedVectorSetValue(w, 0.0));
752   }
753   if (w_impl->d_array) {
754     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
755     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE));
756     CeedCallBackend(CeedDevicePointwiseMult_Cuda(w_impl->d_array, x_impl->d_array, y_impl->d_array, length));
757   }
758   if (w_impl->h_array) {
759     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
760     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST));
761     CeedCallBackend(CeedHostPointwiseMult_Cuda(w_impl->h_array, x_impl->h_array, y_impl->h_array, length));
762   }
763 
764   return CEED_ERROR_SUCCESS;
765 }
766 
767 //------------------------------------------------------------------------------
768 // Destroy the vector
769 //------------------------------------------------------------------------------
770 static int CeedVectorDestroy_Cuda(const CeedVector vec) {
771   Ceed ceed;
772   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
773   CeedVector_Cuda *impl;
774   CeedCallBackend(CeedVectorGetData(vec, &impl));
775 
776   CeedCallCuda(ceed, cudaFree(impl->d_array_owned));
777   CeedCallBackend(CeedFree(&impl->h_array_owned));
778   CeedCallBackend(CeedFree(&impl));
779 
780   return CEED_ERROR_SUCCESS;
781 }
782 
783 //------------------------------------------------------------------------------
784 // Create a vector of the specified length (does not allocate memory)
785 //------------------------------------------------------------------------------
786 int CeedVectorCreate_Cuda(CeedSize n, CeedVector vec) {
787   CeedVector_Cuda *impl;
788   Ceed             ceed;
789   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
790 
791   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Cuda));
792   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Cuda));
793   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Cuda));
794   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Cuda));
795   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", (int (*)())CeedVectorSetValue_Cuda));
796   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Cuda));
797   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Cuda));
798   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Cuda));
799   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Cuda));
800   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Norm", CeedVectorNorm_Cuda));
801   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Cuda));
802   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Scale", (int (*)())CeedVectorScale_Cuda));
803   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPY", (int (*)())CeedVectorAXPY_Cuda));
804   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPBY", (int (*)())CeedVectorAXPBY_Cuda));
805   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Cuda));
806   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Cuda));
807 
808   CeedCallBackend(CeedCalloc(1, &impl));
809   CeedCallBackend(CeedVectorSetData(vec, impl));
810 
811   return CEED_ERROR_SUCCESS;
812 }
813 
814 //------------------------------------------------------------------------------
815