xref: /libCEED/backends/cuda-ref/ceed-cuda-ref-vector.c (revision 2f439227254593e594d6ab1b51afd92f56e9ee04)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <cuda_runtime.h>
11 #include <math.h>
12 #include <stdbool.h>
13 #include <string.h>
14 
15 #include "../cuda/ceed-cuda-common.h"
16 #include "ceed-cuda-ref.h"
17 
18 //------------------------------------------------------------------------------
19 // Check if host/device sync is needed
20 //------------------------------------------------------------------------------
21 static inline int CeedVectorNeedSync_Cuda(const CeedVector vec, CeedMemType mem_type, bool *need_sync) {
22   bool             has_valid_array = false;
23   CeedVector_Cuda *impl;
24 
25   CeedCallBackend(CeedVectorGetData(vec, &impl));
26   CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array));
27   switch (mem_type) {
28     case CEED_MEM_HOST:
29       *need_sync = has_valid_array && !impl->h_array;
30       break;
31     case CEED_MEM_DEVICE:
32       *need_sync = has_valid_array && !impl->d_array;
33       break;
34   }
35   return CEED_ERROR_SUCCESS;
36 }
37 
38 //------------------------------------------------------------------------------
39 // Sync host to device
40 //------------------------------------------------------------------------------
41 static inline int CeedVectorSyncH2D_Cuda(const CeedVector vec) {
42   Ceed             ceed;
43   CeedSize         length;
44   size_t           bytes;
45   CeedVector_Cuda *impl;
46 
47   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
48   CeedCallBackend(CeedVectorGetData(vec, &impl));
49 
50   CeedCheck(impl->h_array, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
51 
52   CeedCallBackend(CeedVectorGetLength(vec, &length));
53   bytes = length * sizeof(CeedScalar);
54   if (impl->d_array_borrowed) {
55     impl->d_array = impl->d_array_borrowed;
56   } else if (impl->d_array_owned) {
57     impl->d_array = impl->d_array_owned;
58   } else {
59     CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_array_owned, bytes));
60     impl->d_array = impl->d_array_owned;
61   }
62   CeedCallCuda(ceed, cudaMemcpy(impl->d_array, impl->h_array, bytes, cudaMemcpyHostToDevice));
63   return CEED_ERROR_SUCCESS;
64 }
65 
66 //------------------------------------------------------------------------------
67 // Sync device to host
68 //------------------------------------------------------------------------------
69 static inline int CeedVectorSyncD2H_Cuda(const CeedVector vec) {
70   Ceed             ceed;
71   CeedSize         length;
72   CeedVector_Cuda *impl;
73 
74   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
75   CeedCallBackend(CeedVectorGetData(vec, &impl));
76 
77   CeedCheck(impl->d_array, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
78 
79   if (impl->h_array_borrowed) {
80     impl->h_array = impl->h_array_borrowed;
81   } else if (impl->h_array_owned) {
82     impl->h_array = impl->h_array_owned;
83   } else {
84     CeedSize length;
85 
86     CeedCallBackend(CeedVectorGetLength(vec, &length));
87     CeedCallBackend(CeedCalloc(length, &impl->h_array_owned));
88     impl->h_array = impl->h_array_owned;
89   }
90 
91   CeedCallBackend(CeedVectorGetLength(vec, &length));
92   size_t bytes = length * sizeof(CeedScalar);
93 
94   CeedCallCuda(ceed, cudaMemcpy(impl->h_array, impl->d_array, bytes, cudaMemcpyDeviceToHost));
95   return CEED_ERROR_SUCCESS;
96 }
97 
98 //------------------------------------------------------------------------------
99 // Sync arrays
100 //------------------------------------------------------------------------------
101 static int CeedVectorSyncArray_Cuda(const CeedVector vec, CeedMemType mem_type) {
102   bool need_sync = false;
103 
104   // Check whether device/host sync is needed
105   CeedCallBackend(CeedVectorNeedSync_Cuda(vec, mem_type, &need_sync));
106   if (!need_sync) return CEED_ERROR_SUCCESS;
107 
108   switch (mem_type) {
109     case CEED_MEM_HOST:
110       return CeedVectorSyncD2H_Cuda(vec);
111     case CEED_MEM_DEVICE:
112       return CeedVectorSyncH2D_Cuda(vec);
113   }
114   return CEED_ERROR_UNSUPPORTED;
115 }
116 
117 //------------------------------------------------------------------------------
118 // Set all pointers as invalid
119 //------------------------------------------------------------------------------
120 static inline int CeedVectorSetAllInvalid_Cuda(const CeedVector vec) {
121   CeedVector_Cuda *impl;
122 
123   CeedCallBackend(CeedVectorGetData(vec, &impl));
124   impl->h_array = NULL;
125   impl->d_array = NULL;
126   return CEED_ERROR_SUCCESS;
127 }
128 
129 //------------------------------------------------------------------------------
130 // Check if CeedVector has any valid pointer
131 //------------------------------------------------------------------------------
132 static inline int CeedVectorHasValidArray_Cuda(const CeedVector vec, bool *has_valid_array) {
133   CeedVector_Cuda *impl;
134 
135   CeedCallBackend(CeedVectorGetData(vec, &impl));
136   *has_valid_array = impl->h_array || impl->d_array;
137   return CEED_ERROR_SUCCESS;
138 }
139 
140 //------------------------------------------------------------------------------
141 // Check if has array of given type
142 //------------------------------------------------------------------------------
143 static inline int CeedVectorHasArrayOfType_Cuda(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) {
144   CeedVector_Cuda *impl;
145 
146   CeedCallBackend(CeedVectorGetData(vec, &impl));
147   switch (mem_type) {
148     case CEED_MEM_HOST:
149       *has_array_of_type = impl->h_array_borrowed || impl->h_array_owned;
150       break;
151     case CEED_MEM_DEVICE:
152       *has_array_of_type = impl->d_array_borrowed || impl->d_array_owned;
153       break;
154   }
155   return CEED_ERROR_SUCCESS;
156 }
157 
158 //------------------------------------------------------------------------------
159 // Check if has borrowed array of given type
160 //------------------------------------------------------------------------------
161 static inline int CeedVectorHasBorrowedArrayOfType_Cuda(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) {
162   CeedVector_Cuda *impl;
163 
164   CeedCallBackend(CeedVectorGetData(vec, &impl));
165   switch (mem_type) {
166     case CEED_MEM_HOST:
167       *has_borrowed_array_of_type = impl->h_array_borrowed;
168       break;
169     case CEED_MEM_DEVICE:
170       *has_borrowed_array_of_type = impl->d_array_borrowed;
171       break;
172   }
173   return CEED_ERROR_SUCCESS;
174 }
175 
176 //------------------------------------------------------------------------------
177 // Set array from host
178 //------------------------------------------------------------------------------
179 static int CeedVectorSetArrayHost_Cuda(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
180   CeedVector_Cuda *impl;
181 
182   CeedCallBackend(CeedVectorGetData(vec, &impl));
183   switch (copy_mode) {
184     case CEED_COPY_VALUES: {
185       if (!impl->h_array_owned) {
186         CeedSize length;
187 
188         CeedCallBackend(CeedVectorGetLength(vec, &length));
189         CeedCallBackend(CeedMalloc(length, &impl->h_array_owned));
190       }
191       impl->h_array_borrowed = NULL;
192       impl->h_array          = impl->h_array_owned;
193       if (array) {
194         CeedSize length;
195         size_t   bytes;
196 
197         CeedCallBackend(CeedVectorGetLength(vec, &length));
198         bytes = length * sizeof(CeedScalar);
199         memcpy(impl->h_array, array, bytes);
200       }
201     } break;
202     case CEED_OWN_POINTER:
203       CeedCallBackend(CeedFree(&impl->h_array_owned));
204       impl->h_array_owned    = array;
205       impl->h_array_borrowed = NULL;
206       impl->h_array          = array;
207       break;
208     case CEED_USE_POINTER:
209       CeedCallBackend(CeedFree(&impl->h_array_owned));
210       impl->h_array_borrowed = array;
211       impl->h_array          = array;
212       break;
213   }
214   return CEED_ERROR_SUCCESS;
215 }
216 
217 //------------------------------------------------------------------------------
218 // Set array from device
219 //------------------------------------------------------------------------------
220 static int CeedVectorSetArrayDevice_Cuda(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
221   Ceed             ceed;
222   CeedVector_Cuda *impl;
223 
224   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
225   CeedCallBackend(CeedVectorGetData(vec, &impl));
226   switch (copy_mode) {
227     case CEED_COPY_VALUES: {
228       CeedSize length;
229       size_t   bytes;
230 
231       CeedCallBackend(CeedVectorGetLength(vec, &length));
232       bytes = length * sizeof(CeedScalar);
233       if (!impl->d_array_owned) {
234         CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_array_owned, bytes));
235       }
236       impl->d_array_borrowed = NULL;
237       impl->d_array          = impl->d_array_owned;
238       if (array) CeedCallCuda(ceed, cudaMemcpy(impl->d_array, array, bytes, cudaMemcpyDeviceToDevice));
239     } break;
240     case CEED_OWN_POINTER:
241       CeedCallCuda(ceed, cudaFree(impl->d_array_owned));
242       impl->d_array_owned    = array;
243       impl->d_array_borrowed = NULL;
244       impl->d_array          = array;
245       break;
246     case CEED_USE_POINTER:
247       CeedCallCuda(ceed, cudaFree(impl->d_array_owned));
248       impl->d_array_owned    = NULL;
249       impl->d_array_borrowed = array;
250       impl->d_array          = array;
251       break;
252   }
253   return CEED_ERROR_SUCCESS;
254 }
255 
256 //------------------------------------------------------------------------------
257 // Set the array used by a vector,
258 //   freeing any previously allocated array if applicable
259 //------------------------------------------------------------------------------
260 static int CeedVectorSetArray_Cuda(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) {
261   Ceed             ceed;
262   CeedVector_Cuda *impl;
263 
264   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
265   CeedCallBackend(CeedVectorGetData(vec, &impl));
266   CeedCallBackend(CeedVectorSetAllInvalid_Cuda(vec));
267   switch (mem_type) {
268     case CEED_MEM_HOST:
269       return CeedVectorSetArrayHost_Cuda(vec, copy_mode, array);
270     case CEED_MEM_DEVICE:
271       return CeedVectorSetArrayDevice_Cuda(vec, copy_mode, array);
272   }
273   return CEED_ERROR_UNSUPPORTED;
274 }
275 
276 //------------------------------------------------------------------------------
277 // Set host array to value
278 //------------------------------------------------------------------------------
279 static int CeedHostSetValue_Cuda(CeedScalar *h_array, CeedSize length, CeedScalar val) {
280   for (CeedSize i = 0; i < length; i++) h_array[i] = val;
281   return CEED_ERROR_SUCCESS;
282 }
283 
284 //------------------------------------------------------------------------------
285 // Set device array to value (impl in .cu file)
286 //------------------------------------------------------------------------------
287 int CeedDeviceSetValue_Cuda(CeedScalar *d_array, CeedSize length, CeedScalar val);
288 
289 //------------------------------------------------------------------------------
290 // Set a vector to a value
291 //------------------------------------------------------------------------------
292 static int CeedVectorSetValue_Cuda(CeedVector vec, CeedScalar val) {
293   Ceed             ceed;
294   CeedSize         length;
295   CeedVector_Cuda *impl;
296 
297   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
298   CeedCallBackend(CeedVectorGetData(vec, &impl));
299   CeedCallBackend(CeedVectorGetLength(vec, &length));
300   // Set value for synced device/host array
301   if (!impl->d_array && !impl->h_array) {
302     if (impl->d_array_borrowed) {
303       impl->d_array = impl->d_array_borrowed;
304     } else if (impl->h_array_borrowed) {
305       impl->h_array = impl->h_array_borrowed;
306     } else if (impl->d_array_owned) {
307       impl->d_array = impl->d_array_owned;
308     } else if (impl->h_array_owned) {
309       impl->h_array = impl->h_array_owned;
310     } else {
311       CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL));
312     }
313   }
314   if (impl->d_array) {
315     CeedCallBackend(CeedDeviceSetValue_Cuda(impl->d_array, length, val));
316     impl->h_array = NULL;
317   }
318   if (impl->h_array) {
319     CeedCallBackend(CeedHostSetValue_Cuda(impl->h_array, length, val));
320     impl->d_array = NULL;
321   }
322   return CEED_ERROR_SUCCESS;
323 }
324 
325 //------------------------------------------------------------------------------
326 // Vector Take Array
327 //------------------------------------------------------------------------------
328 static int CeedVectorTakeArray_Cuda(CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
329   Ceed             ceed;
330   CeedVector_Cuda *impl;
331 
332   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
333   CeedCallBackend(CeedVectorGetData(vec, &impl));
334   // Sync array to requested mem_type
335   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
336   // Update pointer
337   switch (mem_type) {
338     case CEED_MEM_HOST:
339       (*array)               = impl->h_array_borrowed;
340       impl->h_array_borrowed = NULL;
341       impl->h_array          = NULL;
342       break;
343     case CEED_MEM_DEVICE:
344       (*array)               = impl->d_array_borrowed;
345       impl->d_array_borrowed = NULL;
346       impl->d_array          = NULL;
347       break;
348   }
349   return CEED_ERROR_SUCCESS;
350 }
351 
352 //------------------------------------------------------------------------------
353 // Core logic for array syncronization for GetArray.
354 //   If a different memory type is most up to date, this will perform a copy
355 //------------------------------------------------------------------------------
356 static int CeedVectorGetArrayCore_Cuda(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
357   Ceed             ceed;
358   CeedVector_Cuda *impl;
359 
360   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
361   CeedCallBackend(CeedVectorGetData(vec, &impl));
362   // Sync array to requested mem_type
363   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
364   // Update pointer
365   switch (mem_type) {
366     case CEED_MEM_HOST:
367       *array = impl->h_array;
368       break;
369     case CEED_MEM_DEVICE:
370       *array = impl->d_array;
371       break;
372   }
373   return CEED_ERROR_SUCCESS;
374 }
375 
376 //------------------------------------------------------------------------------
377 // Get read-only access to a vector via the specified mem_type
378 //------------------------------------------------------------------------------
379 static int CeedVectorGetArrayRead_Cuda(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) {
380   return CeedVectorGetArrayCore_Cuda(vec, mem_type, (CeedScalar **)array);
381 }
382 
383 //------------------------------------------------------------------------------
384 // Get read/write access to a vector via the specified mem_type
385 //------------------------------------------------------------------------------
386 static int CeedVectorGetArray_Cuda(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
387   CeedVector_Cuda *impl;
388 
389   CeedCallBackend(CeedVectorGetData(vec, &impl));
390   CeedCallBackend(CeedVectorGetArrayCore_Cuda(vec, mem_type, array));
391   CeedCallBackend(CeedVectorSetAllInvalid_Cuda(vec));
392   switch (mem_type) {
393     case CEED_MEM_HOST:
394       impl->h_array = *array;
395       break;
396     case CEED_MEM_DEVICE:
397       impl->d_array = *array;
398       break;
399   }
400   return CEED_ERROR_SUCCESS;
401 }
402 
403 //------------------------------------------------------------------------------
404 // Get write access to a vector via the specified mem_type
405 //------------------------------------------------------------------------------
406 static int CeedVectorGetArrayWrite_Cuda(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
407   bool             has_array_of_type = true;
408   CeedVector_Cuda *impl;
409 
410   CeedCallBackend(CeedVectorGetData(vec, &impl));
411   CeedCallBackend(CeedVectorHasArrayOfType_Cuda(vec, mem_type, &has_array_of_type));
412   if (!has_array_of_type) {
413     // Allocate if array is not yet allocated
414     CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL));
415   } else {
416     // Select dirty array
417     switch (mem_type) {
418       case CEED_MEM_HOST:
419         if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed;
420         else impl->h_array = impl->h_array_owned;
421         break;
422       case CEED_MEM_DEVICE:
423         if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed;
424         else impl->d_array = impl->d_array_owned;
425     }
426   }
427   return CeedVectorGetArray_Cuda(vec, mem_type, array);
428 }
429 
430 //------------------------------------------------------------------------------
431 // Get the norm of a CeedVector
432 //------------------------------------------------------------------------------
433 static int CeedVectorNorm_Cuda(CeedVector vec, CeedNormType type, CeedScalar *norm) {
434   Ceed     ceed;
435   CeedSize length;
436 #if CUDA_VERSION < 12000
437   CeedSize num_calls;
438 #endif
439   const CeedScalar *d_array;
440   CeedVector_Cuda  *impl;
441   cublasHandle_t    handle;
442 
443   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
444   CeedCallBackend(CeedVectorGetData(vec, &impl));
445   CeedCallBackend(CeedVectorGetLength(vec, &length));
446   CeedCallBackend(CeedGetCublasHandle_Cuda(ceed, &handle));
447 
448 #if CUDA_VERSION < 12000
449   // With CUDA 12, we can use the 64-bit integer interface. Prior to that,
450   // we need to check if the vector is too long to handle with int32,
451   // and if so, divide it into subsections for repeated cuBLAS calls.
452   num_calls = length / INT_MAX;
453   if (length % INT_MAX > 0) num_calls += 1;
454 #endif
455 
456   // Compute norm
457   CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array));
458   switch (type) {
459     case CEED_NORM_1: {
460       *norm = 0.0;
461       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
462 #if CUDA_VERSION >= 12000  // We have CUDA 12, and can use 64-bit integers
463         CeedCallCublas(ceed, cublasSasum_64(handle, (int64_t)length, (float *)d_array, 1, (float *)norm));
464 #else
465         float  sub_norm = 0.0;
466         float *d_array_start;
467 
468         for (CeedInt i = 0; i < num_calls; i++) {
469           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
470           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
471           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
472 
473           CeedCallCublas(ceed, cublasSasum(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
474           *norm += sub_norm;
475         }
476 #endif
477       } else {
478 #if CUDA_VERSION >= 12000
479         CeedCallCublas(ceed, cublasDasum_64(handle, (int64_t)length, (double *)d_array, 1, (double *)norm));
480 #else
481         double  sub_norm = 0.0;
482         double *d_array_start;
483 
484         for (CeedInt i = 0; i < num_calls; i++) {
485           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
486           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
487           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
488 
489           CeedCallCublas(ceed, cublasDasum(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
490           *norm += sub_norm;
491         }
492 #endif
493       }
494       break;
495     }
496     case CEED_NORM_2: {
497       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
498 #if CUDA_VERSION >= 12000
499         CeedCallCublas(ceed, cublasSnrm2_64(handle, (int64_t)length, (float *)d_array, 1, (float *)norm));
500 #else
501         float  sub_norm = 0.0, norm_sum = 0.0;
502         float *d_array_start;
503 
504         for (CeedInt i = 0; i < num_calls; i++) {
505           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
506           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
507           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
508 
509           CeedCallCublas(ceed, cublasSnrm2(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
510           norm_sum += sub_norm * sub_norm;
511         }
512         *norm            = sqrt(norm_sum);
513 #endif
514       } else {
515 #if CUDA_VERSION >= 12000
516         CeedCallCublas(ceed, cublasDnrm2_64(handle, (int64_t)length, (double *)d_array, 1, (double *)norm));
517 #else
518         double  sub_norm = 0.0, norm_sum = 0.0;
519         double *d_array_start;
520 
521         for (CeedInt i = 0; i < num_calls; i++) {
522           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
523           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
524           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
525 
526           CeedCallCublas(ceed, cublasDnrm2(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
527           norm_sum += sub_norm * sub_norm;
528         }
529         *norm = sqrt(norm_sum);
530 #endif
531       }
532       break;
533     }
534     case CEED_NORM_MAX: {
535       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
536 #if CUDA_VERSION >= 12000
537         int64_t    index;
538         CeedScalar norm_no_abs;
539 
540         CeedCallCublas(ceed, cublasIsamax_64(handle, (int64_t)length, (float *)d_array, 1, &index));
541         CeedCallCuda(ceed, cudaMemcpy(&norm_no_abs, impl->d_array + index - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
542         *norm = fabs(norm_no_abs);
543 #else
544         CeedInt index;
545         float   sub_max = 0.0, current_max = 0.0;
546         float  *d_array_start;
547 
548         for (CeedInt i = 0; i < num_calls; i++) {
549           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
550           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
551           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
552 
553           CeedCallCublas(ceed, cublasIsamax(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &index));
554           CeedCallCuda(ceed, cudaMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
555           if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
556         }
557         *norm = current_max;
558 #endif
559       } else {
560 #if CUDA_VERSION >= 12000
561         int64_t    index;
562         CeedScalar norm_no_abs;
563 
564         CeedCallCublas(ceed, cublasIdamax_64(handle, (int64_t)length, (double *)d_array, 1, &index));
565         CeedCallCuda(ceed, cudaMemcpy(&norm_no_abs, impl->d_array + index - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
566         *norm = fabs(norm_no_abs);
567 #else
568         CeedInt index;
569         double  sub_max = 0.0, current_max = 0.0;
570         double *d_array_start;
571 
572         for (CeedInt i = 0; i < num_calls; i++) {
573           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
574           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
575           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
576 
577           CeedCallCublas(ceed, cublasIdamax(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &index));
578           CeedCallCuda(ceed, cudaMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
579           if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
580         }
581         *norm = current_max;
582 #endif
583       }
584       break;
585     }
586   }
587   CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array));
588   return CEED_ERROR_SUCCESS;
589 }
590 
591 //------------------------------------------------------------------------------
592 // Take reciprocal of a vector on host
593 //------------------------------------------------------------------------------
594 static int CeedHostReciprocal_Cuda(CeedScalar *h_array, CeedSize length) {
595   for (CeedSize i = 0; i < length; i++) {
596     if (fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i];
597   }
598   return CEED_ERROR_SUCCESS;
599 }
600 
601 //------------------------------------------------------------------------------
602 // Take reciprocal of a vector on device (impl in .cu file)
603 //------------------------------------------------------------------------------
604 int CeedDeviceReciprocal_Cuda(CeedScalar *d_array, CeedSize length);
605 
606 //------------------------------------------------------------------------------
607 // Take reciprocal of a vector
608 //------------------------------------------------------------------------------
609 static int CeedVectorReciprocal_Cuda(CeedVector vec) {
610   Ceed             ceed;
611   CeedSize         length;
612   CeedVector_Cuda *impl;
613 
614   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
615   CeedCallBackend(CeedVectorGetData(vec, &impl));
616   CeedCallBackend(CeedVectorGetLength(vec, &length));
617   // Set value for synced device/host array
618   if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Cuda(impl->d_array, length));
619   if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Cuda(impl->h_array, length));
620   return CEED_ERROR_SUCCESS;
621 }
622 
623 //------------------------------------------------------------------------------
624 // Compute x = alpha x on the host
625 //------------------------------------------------------------------------------
626 static int CeedHostScale_Cuda(CeedScalar *x_array, CeedScalar alpha, CeedSize length) {
627   for (CeedSize i = 0; i < length; i++) x_array[i] *= alpha;
628   return CEED_ERROR_SUCCESS;
629 }
630 
631 //------------------------------------------------------------------------------
632 // Compute x = alpha x on device (impl in .cu file)
633 //------------------------------------------------------------------------------
634 int CeedDeviceScale_Cuda(CeedScalar *x_array, CeedScalar alpha, CeedSize length);
635 
636 //------------------------------------------------------------------------------
637 // Compute x = alpha x
638 //------------------------------------------------------------------------------
639 static int CeedVectorScale_Cuda(CeedVector x, CeedScalar alpha) {
640   Ceed             ceed;
641   CeedSize         length;
642   CeedVector_Cuda *x_impl;
643 
644   CeedCallBackend(CeedVectorGetCeed(x, &ceed));
645   CeedCallBackend(CeedVectorGetData(x, &x_impl));
646   CeedCallBackend(CeedVectorGetLength(x, &length));
647   // Set value for synced device/host array
648   if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Cuda(x_impl->d_array, alpha, length));
649   if (x_impl->h_array) CeedCallBackend(CeedHostScale_Cuda(x_impl->h_array, alpha, length));
650   return CEED_ERROR_SUCCESS;
651 }
652 
653 //------------------------------------------------------------------------------
654 // Compute y = alpha x + y on the host
655 //------------------------------------------------------------------------------
656 static int CeedHostAXPY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) {
657   for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i];
658   return CEED_ERROR_SUCCESS;
659 }
660 
661 //------------------------------------------------------------------------------
662 // Compute y = alpha x + y on device (impl in .cu file)
663 //------------------------------------------------------------------------------
664 int CeedDeviceAXPY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length);
665 
666 //------------------------------------------------------------------------------
667 // Compute y = alpha x + y
668 //------------------------------------------------------------------------------
669 static int CeedVectorAXPY_Cuda(CeedVector y, CeedScalar alpha, CeedVector x) {
670   Ceed             ceed;
671   CeedSize         length;
672   CeedVector_Cuda *y_impl, *x_impl;
673 
674   CeedCallBackend(CeedVectorGetCeed(y, &ceed));
675   CeedCallBackend(CeedVectorGetData(y, &y_impl));
676   CeedCallBackend(CeedVectorGetData(x, &x_impl));
677   CeedCallBackend(CeedVectorGetLength(y, &length));
678   // Set value for synced device/host array
679   if (y_impl->d_array) {
680     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
681     CeedCallBackend(CeedDeviceAXPY_Cuda(y_impl->d_array, alpha, x_impl->d_array, length));
682   }
683   if (y_impl->h_array) {
684     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
685     CeedCallBackend(CeedHostAXPY_Cuda(y_impl->h_array, alpha, x_impl->h_array, length));
686   }
687   return CEED_ERROR_SUCCESS;
688 }
689 
690 //------------------------------------------------------------------------------
691 // Compute y = alpha x + beta y on the host
692 //------------------------------------------------------------------------------
693 static int CeedHostAXPBY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length) {
694   for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i] + beta * y_array[i];
695   return CEED_ERROR_SUCCESS;
696 }
697 
698 //------------------------------------------------------------------------------
699 // Compute y = alpha x + beta y on device (impl in .cu file)
700 //------------------------------------------------------------------------------
701 int CeedDeviceAXPBY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length);
702 
703 //------------------------------------------------------------------------------
704 // Compute y = alpha x + beta y
705 //------------------------------------------------------------------------------
706 static int CeedVectorAXPBY_Cuda(CeedVector y, CeedScalar alpha, CeedScalar beta, CeedVector x) {
707   Ceed             ceed;
708   CeedSize         length;
709   CeedVector_Cuda *y_impl, *x_impl;
710 
711   CeedCallBackend(CeedVectorGetCeed(y, &ceed));
712   CeedCallBackend(CeedVectorGetData(y, &y_impl));
713   CeedCallBackend(CeedVectorGetData(x, &x_impl));
714   CeedCallBackend(CeedVectorGetLength(y, &length));
715   // Set value for synced device/host array
716   if (y_impl->d_array) {
717     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
718     CeedCallBackend(CeedDeviceAXPBY_Cuda(y_impl->d_array, alpha, beta, x_impl->d_array, length));
719   }
720   if (y_impl->h_array) {
721     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
722     CeedCallBackend(CeedHostAXPBY_Cuda(y_impl->h_array, alpha, beta, x_impl->h_array, length));
723   }
724   return CEED_ERROR_SUCCESS;
725 }
726 
727 //------------------------------------------------------------------------------
728 // Compute the pointwise multiplication w = x .* y on the host
729 //------------------------------------------------------------------------------
730 static int CeedHostPointwiseMult_Cuda(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) {
731   for (CeedSize i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i];
732   return CEED_ERROR_SUCCESS;
733 }
734 
735 //------------------------------------------------------------------------------
736 // Compute the pointwise multiplication w = x .* y on device (impl in .cu file)
737 //------------------------------------------------------------------------------
738 int CeedDevicePointwiseMult_Cuda(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length);
739 
740 //------------------------------------------------------------------------------
741 // Compute the pointwise multiplication w = x .* y
742 //------------------------------------------------------------------------------
743 static int CeedVectorPointwiseMult_Cuda(CeedVector w, CeedVector x, CeedVector y) {
744   Ceed             ceed;
745   CeedSize         length;
746   CeedVector_Cuda *w_impl, *x_impl, *y_impl;
747 
748   CeedCallBackend(CeedVectorGetCeed(w, &ceed));
749   CeedCallBackend(CeedVectorGetData(w, &w_impl));
750   CeedCallBackend(CeedVectorGetData(x, &x_impl));
751   CeedCallBackend(CeedVectorGetData(y, &y_impl));
752   CeedCallBackend(CeedVectorGetLength(w, &length));
753   // Set value for synced device/host array
754   if (!w_impl->d_array && !w_impl->h_array) {
755     CeedCallBackend(CeedVectorSetValue(w, 0.0));
756   }
757   if (w_impl->d_array) {
758     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
759     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE));
760     CeedCallBackend(CeedDevicePointwiseMult_Cuda(w_impl->d_array, x_impl->d_array, y_impl->d_array, length));
761   }
762   if (w_impl->h_array) {
763     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
764     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST));
765     CeedCallBackend(CeedHostPointwiseMult_Cuda(w_impl->h_array, x_impl->h_array, y_impl->h_array, length));
766   }
767   return CEED_ERROR_SUCCESS;
768 }
769 
//------------------------------------------------------------------------------
// Destroy the vector's backend state. Frees, in order:
// - the owned device array, with cudaFree;
// - the owned host array;
// - the CeedVector_Cuda struct itself.
// Any failure returns immediately, before the remaining frees.
//------------------------------------------------------------------------------
static int CeedVectorDestroy_Cuda(const CeedVector vec) {
  CeedVector_Cuda *data;
  Ceed             ceed;

  CeedCallBackend(CeedVectorGetData(vec, &data));
  CeedCallBackend(CeedVectorGetCeed(vec, &ceed));

  CeedCallCuda(ceed, cudaFree(data->d_array_owned));
  CeedCallBackend(CeedFree(&data->h_array_owned));
  CeedCallBackend(CeedFree(&data));
  return CEED_ERROR_SUCCESS;
}
784 
785 //------------------------------------------------------------------------------
786 // Create a vector of the specified length (does not allocate memory)
787 //------------------------------------------------------------------------------
788 int CeedVectorCreate_Cuda(CeedSize n, CeedVector vec) {
789   CeedVector_Cuda *impl;
790   Ceed             ceed;
791 
792   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
793   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Cuda));
794   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Cuda));
795   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Cuda));
796   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Cuda));
797   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", (int (*)())CeedVectorSetValue_Cuda));
798   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Cuda));
799   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Cuda));
800   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Cuda));
801   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Cuda));
802   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Norm", CeedVectorNorm_Cuda));
803   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Cuda));
804   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Scale", (int (*)())CeedVectorScale_Cuda));
805   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPY", (int (*)())CeedVectorAXPY_Cuda));
806   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPBY", (int (*)())CeedVectorAXPBY_Cuda));
807   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Cuda));
808   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Cuda));
809   CeedCallBackend(CeedCalloc(1, &impl));
810   CeedCallBackend(CeedVectorSetData(vec, impl));
811   return CEED_ERROR_SUCCESS;
812 }
813 
814 //------------------------------------------------------------------------------
815