xref: /libCEED/backends/cuda-ref/ceed-cuda-ref-vector.c (revision 1c66c397a67401e1a222857807e6e5b7c45b88c0)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <cuda_runtime.h>
11 #include <math.h>
12 #include <stdbool.h>
13 #include <string.h>
14 
15 #include "../cuda/ceed-cuda-common.h"
16 #include "ceed-cuda-ref.h"
17 
18 //------------------------------------------------------------------------------
19 // Check if host/device sync is needed
20 //------------------------------------------------------------------------------
21 static inline int CeedVectorNeedSync_Cuda(const CeedVector vec, CeedMemType mem_type, bool *need_sync) {
22   CeedVector_Cuda *impl;
23   CeedCallBackend(CeedVectorGetData(vec, &impl));
24 
25   bool has_valid_array = false;
26   CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array));
27   switch (mem_type) {
28     case CEED_MEM_HOST:
29       *need_sync = has_valid_array && !impl->h_array;
30       break;
31     case CEED_MEM_DEVICE:
32       *need_sync = has_valid_array && !impl->d_array;
33       break;
34   }
35 
36   return CEED_ERROR_SUCCESS;
37 }
38 
39 //------------------------------------------------------------------------------
40 // Sync host to device
41 //------------------------------------------------------------------------------
42 static inline int CeedVectorSyncH2D_Cuda(const CeedVector vec) {
43   Ceed ceed;
44   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
45   CeedVector_Cuda *impl;
46   CeedCallBackend(CeedVectorGetData(vec, &impl));
47 
48   CeedCheck(impl->h_array, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
49 
50   CeedSize length;
51   CeedCallBackend(CeedVectorGetLength(vec, &length));
52   size_t bytes = length * sizeof(CeedScalar);
53 
54   if (impl->d_array_borrowed) {
55     impl->d_array = impl->d_array_borrowed;
56   } else if (impl->d_array_owned) {
57     impl->d_array = impl->d_array_owned;
58   } else {
59     CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_array_owned, bytes));
60     impl->d_array = impl->d_array_owned;
61   }
62 
63   CeedCallCuda(ceed, cudaMemcpy(impl->d_array, impl->h_array, bytes, cudaMemcpyHostToDevice));
64 
65   return CEED_ERROR_SUCCESS;
66 }
67 
68 //------------------------------------------------------------------------------
69 // Sync device to host
70 //------------------------------------------------------------------------------
71 static inline int CeedVectorSyncD2H_Cuda(const CeedVector vec) {
72   Ceed ceed;
73   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
74   CeedVector_Cuda *impl;
75   CeedCallBackend(CeedVectorGetData(vec, &impl));
76 
77   CeedCheck(impl->d_array, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
78 
79   if (impl->h_array_borrowed) {
80     impl->h_array = impl->h_array_borrowed;
81   } else if (impl->h_array_owned) {
82     impl->h_array = impl->h_array_owned;
83   } else {
84     CeedSize length;
85     CeedCallBackend(CeedVectorGetLength(vec, &length));
86     CeedCallBackend(CeedCalloc(length, &impl->h_array_owned));
87     impl->h_array = impl->h_array_owned;
88   }
89 
90   CeedSize length;
91   CeedCallBackend(CeedVectorGetLength(vec, &length));
92   size_t bytes = length * sizeof(CeedScalar);
93   CeedCallCuda(ceed, cudaMemcpy(impl->h_array, impl->d_array, bytes, cudaMemcpyDeviceToHost));
94 
95   return CEED_ERROR_SUCCESS;
96 }
97 
98 //------------------------------------------------------------------------------
99 // Sync arrays
100 //------------------------------------------------------------------------------
101 static int CeedVectorSyncArray_Cuda(const CeedVector vec, CeedMemType mem_type) {
102   // Check whether device/host sync is needed
103   bool need_sync = false;
104   CeedCallBackend(CeedVectorNeedSync_Cuda(vec, mem_type, &need_sync));
105   if (!need_sync) return CEED_ERROR_SUCCESS;
106 
107   switch (mem_type) {
108     case CEED_MEM_HOST:
109       return CeedVectorSyncD2H_Cuda(vec);
110     case CEED_MEM_DEVICE:
111       return CeedVectorSyncH2D_Cuda(vec);
112   }
113   return CEED_ERROR_UNSUPPORTED;
114 }
115 
116 //------------------------------------------------------------------------------
117 // Set all pointers as invalid
118 //------------------------------------------------------------------------------
119 static inline int CeedVectorSetAllInvalid_Cuda(const CeedVector vec) {
120   CeedVector_Cuda *impl;
121   CeedCallBackend(CeedVectorGetData(vec, &impl));
122 
123   impl->h_array = NULL;
124   impl->d_array = NULL;
125 
126   return CEED_ERROR_SUCCESS;
127 }
128 
129 //------------------------------------------------------------------------------
130 // Check if CeedVector has any valid pointer
131 //------------------------------------------------------------------------------
132 static inline int CeedVectorHasValidArray_Cuda(const CeedVector vec, bool *has_valid_array) {
133   CeedVector_Cuda *impl;
134   CeedCallBackend(CeedVectorGetData(vec, &impl));
135 
136   *has_valid_array = impl->h_array || impl->d_array;
137 
138   return CEED_ERROR_SUCCESS;
139 }
140 
141 //------------------------------------------------------------------------------
142 // Check if has array of given type
143 //------------------------------------------------------------------------------
144 static inline int CeedVectorHasArrayOfType_Cuda(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) {
145   CeedVector_Cuda *impl;
146   CeedCallBackend(CeedVectorGetData(vec, &impl));
147 
148   switch (mem_type) {
149     case CEED_MEM_HOST:
150       *has_array_of_type = impl->h_array_borrowed || impl->h_array_owned;
151       break;
152     case CEED_MEM_DEVICE:
153       *has_array_of_type = impl->d_array_borrowed || impl->d_array_owned;
154       break;
155   }
156 
157   return CEED_ERROR_SUCCESS;
158 }
159 
160 //------------------------------------------------------------------------------
161 // Check if has borrowed array of given type
162 //------------------------------------------------------------------------------
163 static inline int CeedVectorHasBorrowedArrayOfType_Cuda(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) {
164   CeedVector_Cuda *impl;
165   CeedCallBackend(CeedVectorGetData(vec, &impl));
166 
167   switch (mem_type) {
168     case CEED_MEM_HOST:
169       *has_borrowed_array_of_type = impl->h_array_borrowed;
170       break;
171     case CEED_MEM_DEVICE:
172       *has_borrowed_array_of_type = impl->d_array_borrowed;
173       break;
174   }
175 
176   return CEED_ERROR_SUCCESS;
177 }
178 
179 //------------------------------------------------------------------------------
180 // Set array from host
181 //------------------------------------------------------------------------------
182 static int CeedVectorSetArrayHost_Cuda(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
183   CeedVector_Cuda *impl;
184   CeedCallBackend(CeedVectorGetData(vec, &impl));
185 
186   switch (copy_mode) {
187     case CEED_COPY_VALUES: {
188       CeedSize length;
189       if (!impl->h_array_owned) {
190         CeedCallBackend(CeedVectorGetLength(vec, &length));
191         CeedCallBackend(CeedMalloc(length, &impl->h_array_owned));
192       }
193       impl->h_array_borrowed = NULL;
194       impl->h_array          = impl->h_array_owned;
195       if (array) {
196         CeedSize length;
197         CeedCallBackend(CeedVectorGetLength(vec, &length));
198         size_t bytes = length * sizeof(CeedScalar);
199         memcpy(impl->h_array, array, bytes);
200       }
201     } break;
202     case CEED_OWN_POINTER:
203       CeedCallBackend(CeedFree(&impl->h_array_owned));
204       impl->h_array_owned    = array;
205       impl->h_array_borrowed = NULL;
206       impl->h_array          = array;
207       break;
208     case CEED_USE_POINTER:
209       CeedCallBackend(CeedFree(&impl->h_array_owned));
210       impl->h_array_borrowed = array;
211       impl->h_array          = array;
212       break;
213   }
214 
215   return CEED_ERROR_SUCCESS;
216 }
217 
218 //------------------------------------------------------------------------------
219 // Set array from device
220 //------------------------------------------------------------------------------
221 static int CeedVectorSetArrayDevice_Cuda(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
222   Ceed ceed;
223   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
224   CeedVector_Cuda *impl;
225   CeedCallBackend(CeedVectorGetData(vec, &impl));
226 
227   switch (copy_mode) {
228     case CEED_COPY_VALUES: {
229       CeedSize length;
230       CeedCallBackend(CeedVectorGetLength(vec, &length));
231       size_t bytes = length * sizeof(CeedScalar);
232       if (!impl->d_array_owned) {
233         CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_array_owned, bytes));
234         impl->d_array = impl->d_array_owned;
235       }
236       impl->d_array_borrowed = NULL;
237       if (array) CeedCallCuda(ceed, cudaMemcpy(impl->d_array, array, bytes, cudaMemcpyDeviceToDevice));
238     } break;
239     case CEED_OWN_POINTER:
240       CeedCallCuda(ceed, cudaFree(impl->d_array_owned));
241       impl->d_array_owned    = array;
242       impl->d_array_borrowed = NULL;
243       impl->d_array          = array;
244       break;
245     case CEED_USE_POINTER:
246       CeedCallCuda(ceed, cudaFree(impl->d_array_owned));
247       impl->d_array_owned    = NULL;
248       impl->d_array_borrowed = array;
249       impl->d_array          = array;
250       break;
251   }
252 
253   return CEED_ERROR_SUCCESS;
254 }
255 
256 //------------------------------------------------------------------------------
257 // Set the array used by a vector,
258 //   freeing any previously allocated array if applicable
259 //------------------------------------------------------------------------------
260 static int CeedVectorSetArray_Cuda(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) {
261   Ceed ceed;
262   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
263   CeedVector_Cuda *impl;
264   CeedCallBackend(CeedVectorGetData(vec, &impl));
265 
266   CeedCallBackend(CeedVectorSetAllInvalid_Cuda(vec));
267   switch (mem_type) {
268     case CEED_MEM_HOST:
269       return CeedVectorSetArrayHost_Cuda(vec, copy_mode, array);
270     case CEED_MEM_DEVICE:
271       return CeedVectorSetArrayDevice_Cuda(vec, copy_mode, array);
272   }
273 
274   return CEED_ERROR_UNSUPPORTED;
275 }
276 
277 //------------------------------------------------------------------------------
278 // Set host array to value
279 //------------------------------------------------------------------------------
280 static int CeedHostSetValue_Cuda(CeedScalar *h_array, CeedSize length, CeedScalar val) {
281   for (CeedSize i = 0; i < length; i++) h_array[i] = val;
282   return CEED_ERROR_SUCCESS;
283 }
284 
285 //------------------------------------------------------------------------------
286 // Set device array to value (impl in .cu file)
287 //------------------------------------------------------------------------------
288 int CeedDeviceSetValue_Cuda(CeedScalar *d_array, CeedSize length, CeedScalar val);
289 
290 //------------------------------------------------------------------------------
291 // Set a vector to a value
292 //------------------------------------------------------------------------------
293 static int CeedVectorSetValue_Cuda(CeedVector vec, CeedScalar val) {
294   Ceed ceed;
295   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
296   CeedVector_Cuda *impl;
297   CeedCallBackend(CeedVectorGetData(vec, &impl));
298   CeedSize length;
299   CeedCallBackend(CeedVectorGetLength(vec, &length));
300 
301   // Set value for synced device/host array
302   if (!impl->d_array && !impl->h_array) {
303     if (impl->d_array_borrowed) {
304       impl->d_array = impl->d_array_borrowed;
305     } else if (impl->h_array_borrowed) {
306       impl->h_array = impl->h_array_borrowed;
307     } else if (impl->d_array_owned) {
308       impl->d_array = impl->d_array_owned;
309     } else if (impl->h_array_owned) {
310       impl->h_array = impl->h_array_owned;
311     } else {
312       CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL));
313     }
314   }
315   if (impl->d_array) {
316     CeedCallBackend(CeedDeviceSetValue_Cuda(impl->d_array, length, val));
317     impl->h_array = NULL;
318   }
319   if (impl->h_array) {
320     CeedCallBackend(CeedHostSetValue_Cuda(impl->h_array, length, val));
321     impl->d_array = NULL;
322   }
323 
324   return CEED_ERROR_SUCCESS;
325 }
326 
327 //------------------------------------------------------------------------------
328 // Vector Take Array
329 //------------------------------------------------------------------------------
330 static int CeedVectorTakeArray_Cuda(CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
331   Ceed ceed;
332   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
333   CeedVector_Cuda *impl;
334   CeedCallBackend(CeedVectorGetData(vec, &impl));
335 
336   // Sync array to requested mem_type
337   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
338 
339   // Update pointer
340   switch (mem_type) {
341     case CEED_MEM_HOST:
342       (*array)               = impl->h_array_borrowed;
343       impl->h_array_borrowed = NULL;
344       impl->h_array          = NULL;
345       break;
346     case CEED_MEM_DEVICE:
347       (*array)               = impl->d_array_borrowed;
348       impl->d_array_borrowed = NULL;
349       impl->d_array          = NULL;
350       break;
351   }
352 
353   return CEED_ERROR_SUCCESS;
354 }
355 
356 //------------------------------------------------------------------------------
357 // Core logic for array syncronization for GetArray.
358 //   If a different memory type is most up to date, this will perform a copy
359 //------------------------------------------------------------------------------
360 static int CeedVectorGetArrayCore_Cuda(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
361   Ceed ceed;
362   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
363   CeedVector_Cuda *impl;
364   CeedCallBackend(CeedVectorGetData(vec, &impl));
365 
366   // Sync array to requested mem_type
367   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
368 
369   // Update pointer
370   switch (mem_type) {
371     case CEED_MEM_HOST:
372       *array = impl->h_array;
373       break;
374     case CEED_MEM_DEVICE:
375       *array = impl->d_array;
376       break;
377   }
378 
379   return CEED_ERROR_SUCCESS;
380 }
381 
382 //------------------------------------------------------------------------------
383 // Get read-only access to a vector via the specified mem_type
384 //------------------------------------------------------------------------------
385 static int CeedVectorGetArrayRead_Cuda(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) {
386   return CeedVectorGetArrayCore_Cuda(vec, mem_type, (CeedScalar **)array);
387 }
388 
389 //------------------------------------------------------------------------------
390 // Get read/write access to a vector via the specified mem_type
391 //------------------------------------------------------------------------------
392 static int CeedVectorGetArray_Cuda(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
393   CeedVector_Cuda *impl;
394   CeedCallBackend(CeedVectorGetData(vec, &impl));
395 
396   CeedCallBackend(CeedVectorGetArrayCore_Cuda(vec, mem_type, array));
397 
398   CeedCallBackend(CeedVectorSetAllInvalid_Cuda(vec));
399   switch (mem_type) {
400     case CEED_MEM_HOST:
401       impl->h_array = *array;
402       break;
403     case CEED_MEM_DEVICE:
404       impl->d_array = *array;
405       break;
406   }
407 
408   return CEED_ERROR_SUCCESS;
409 }
410 
411 //------------------------------------------------------------------------------
412 // Get write access to a vector via the specified mem_type
413 //------------------------------------------------------------------------------
414 static int CeedVectorGetArrayWrite_Cuda(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
415   CeedVector_Cuda *impl;
416   CeedCallBackend(CeedVectorGetData(vec, &impl));
417 
418   bool has_array_of_type = true;
419   CeedCallBackend(CeedVectorHasArrayOfType_Cuda(vec, mem_type, &has_array_of_type));
420   if (!has_array_of_type) {
421     // Allocate if array is not yet allocated
422     CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL));
423   } else {
424     // Select dirty array
425     switch (mem_type) {
426       case CEED_MEM_HOST:
427         if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed;
428         else impl->h_array = impl->h_array_owned;
429         break;
430       case CEED_MEM_DEVICE:
431         if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed;
432         else impl->d_array = impl->d_array_owned;
433     }
434   }
435 
436   return CeedVectorGetArray_Cuda(vec, mem_type, array);
437 }
438 
//------------------------------------------------------------------------------
// Get the norm of a CeedVector
//   Computes the 1-norm, 2-norm or max-norm of vec with cuBLAS on the device
//   array and stores the result in *norm.
//   The scalar precision is selected with an ordinary `if` on
//   CEED_SCALAR_TYPE, so both the float and the double branch are compiled.
//   For CUDA_VERSION < 12000 the vector is processed in chunks of at most
//   INT_MAX entries and the partial results are combined on the host.
//------------------------------------------------------------------------------
static int CeedVectorNorm_Cuda(CeedVector vec, CeedNormType type, CeedScalar *norm) {
  Ceed ceed;
  CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
  CeedVector_Cuda *impl;
  CeedCallBackend(CeedVectorGetData(vec, &impl));
  CeedSize length;
  CeedCallBackend(CeedVectorGetLength(vec, &length));
  cublasHandle_t handle;
  CeedCallBackend(CeedGetCublasHandle_Cuda(ceed, &handle));

#if CUDA_VERSION < 12000
  // With CUDA 12, we can use the 64-bit integer interface. Prior to that,
  // we need to check if the vector is too long to handle with int32,
  // and if so, divide it into subsections for repeated cuBLAS calls.
  // num_calls is length / INT_MAX rounded up.
  CeedSize num_calls = length / INT_MAX;
  if (length % INT_MAX > 0) num_calls += 1;
#endif

  // Compute norm
  // Read access on the device array is held until the restore call at the end
  const CeedScalar *d_array;
  CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array));
  switch (type) {
    case CEED_NORM_1: {
      // Sum of absolute values; chunk results are simply added together
      *norm = 0.0;
      if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
#if CUDA_VERSION >= 12000  // We have CUDA 12, and can use 64-bit integers
        CeedCallCublas(ceed, cublasSasum_64(handle, (int64_t)length, (float *)d_array, 1, (float *)norm));
#else
        float  sub_norm = 0.0;
        float *d_array_start;
        for (CeedInt i = 0; i < num_calls; i++) {
          // Chunk i starts at offset i * INT_MAX; only the last chunk may be shorter
          d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
          CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
          CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
          CeedCallCublas(ceed, cublasSasum(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
          *norm += sub_norm;
        }
#endif
      } else {
#if CUDA_VERSION >= 12000
        CeedCallCublas(ceed, cublasDasum_64(handle, (int64_t)length, (double *)d_array, 1, (double *)norm));
#else
        double  sub_norm = 0.0;
        double *d_array_start;
        for (CeedInt i = 0; i < num_calls; i++) {
          d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
          CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
          CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
          CeedCallCublas(ceed, cublasDasum(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
          *norm += sub_norm;
        }
#endif
      }
      break;
    }
    case CEED_NORM_2: {
      // Euclidean norm; chunk norms are combined as sqrt(sum of squares)
      if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
#if CUDA_VERSION >= 12000
        CeedCallCublas(ceed, cublasSnrm2_64(handle, (int64_t)length, (float *)d_array, 1, (float *)norm));
#else
        float  sub_norm = 0.0, norm_sum = 0.0;
        float *d_array_start;
        for (CeedInt i = 0; i < num_calls; i++) {
          d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
          CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
          CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
          CeedCallCublas(ceed, cublasSnrm2(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
          norm_sum += sub_norm * sub_norm;
        }
        *norm            = sqrt(norm_sum);
#endif
      } else {
#if CUDA_VERSION >= 12000
        CeedCallCublas(ceed, cublasDnrm2_64(handle, (int64_t)length, (double *)d_array, 1, (double *)norm));
#else
        double  sub_norm = 0.0, norm_sum = 0.0;
        double *d_array_start;
        for (CeedInt i = 0; i < num_calls; i++) {
          d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
          CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
          CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
          CeedCallCublas(ceed, cublasDnrm2(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
          norm_sum += sub_norm * sub_norm;
        }
        *norm = sqrt(norm_sum);
#endif
      }
      break;
    }
    case CEED_NORM_MAX: {
      // cuBLAS returns the position of the entry with the largest magnitude;
      // the code subtracts 1 from it to form the offset (cuBLAS documents this
      // result as 1-based), copies that entry to the host and takes fabs.
      if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
#if CUDA_VERSION >= 12000
        int64_t indx;
        CeedCallCublas(ceed, cublasIsamax_64(handle, (int64_t)length, (float *)d_array, 1, &indx));
        CeedScalar normNoAbs;
        // NOTE(review): reads through impl->d_array rather than the local d_array;
        // presumably both refer to the same buffer after GetArrayRead - confirm
        CeedCallCuda(ceed, cudaMemcpy(&normNoAbs, impl->d_array + indx - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
        *norm = fabs(normNoAbs);
#else
        CeedInt indx;
        float   sub_max = 0.0, current_max = 0.0;
        float  *d_array_start;
        for (CeedInt i = 0; i < num_calls; i++) {
          d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
          CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
          CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
          CeedCallCublas(ceed, cublasIsamax(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &indx));
          CeedCallCuda(ceed, cudaMemcpy(&sub_max, d_array_start + indx - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
          // Keep the largest magnitude seen across chunks
          if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
        }
        *norm = current_max;
#endif
      } else {
#if CUDA_VERSION >= 12000
        int64_t indx;
        CeedCallCublas(ceed, cublasIdamax_64(handle, (int64_t)length, (double *)d_array, 1, &indx));
        CeedScalar normNoAbs;
        CeedCallCuda(ceed, cudaMemcpy(&normNoAbs, impl->d_array + indx - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
        *norm = fabs(normNoAbs);
#else
        CeedInt indx;
        double  sub_max = 0.0, current_max = 0.0;
        double *d_array_start;
        for (CeedInt i = 0; i < num_calls; i++) {
          d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
          CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
          CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
          CeedCallCublas(ceed, cublasIdamax(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &indx));
          CeedCallCuda(ceed, cudaMemcpy(&sub_max, d_array_start + indx - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
          if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
        }
        *norm = current_max;
#endif
      }
      break;
    }
  }
  // Release the read access taken above
  CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array));

  return CEED_ERROR_SUCCESS;
}
582 
583 //------------------------------------------------------------------------------
584 // Take reciprocal of a vector on host
585 //------------------------------------------------------------------------------
586 static int CeedHostReciprocal_Cuda(CeedScalar *h_array, CeedSize length) {
587   for (CeedSize i = 0; i < length; i++) {
588     if (fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i];
589   }
590   return CEED_ERROR_SUCCESS;
591 }
592 
593 //------------------------------------------------------------------------------
594 // Take reciprocal of a vector on device (impl in .cu file)
595 //------------------------------------------------------------------------------
596 int CeedDeviceReciprocal_Cuda(CeedScalar *d_array, CeedSize length);
597 
598 //------------------------------------------------------------------------------
599 // Take reciprocal of a vector
600 //------------------------------------------------------------------------------
601 static int CeedVectorReciprocal_Cuda(CeedVector vec) {
602   Ceed ceed;
603   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
604   CeedVector_Cuda *impl;
605   CeedCallBackend(CeedVectorGetData(vec, &impl));
606   CeedSize length;
607   CeedCallBackend(CeedVectorGetLength(vec, &length));
608 
609   // Set value for synced device/host array
610   if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Cuda(impl->d_array, length));
611   if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Cuda(impl->h_array, length));
612 
613   return CEED_ERROR_SUCCESS;
614 }
615 
616 //------------------------------------------------------------------------------
617 // Compute x = alpha x on the host
618 //------------------------------------------------------------------------------
619 static int CeedHostScale_Cuda(CeedScalar *x_array, CeedScalar alpha, CeedSize length) {
620   for (CeedSize i = 0; i < length; i++) x_array[i] *= alpha;
621   return CEED_ERROR_SUCCESS;
622 }
623 
624 //------------------------------------------------------------------------------
625 // Compute x = alpha x on device (impl in .cu file)
626 //------------------------------------------------------------------------------
627 int CeedDeviceScale_Cuda(CeedScalar *x_array, CeedScalar alpha, CeedSize length);
628 
629 //------------------------------------------------------------------------------
630 // Compute x = alpha x
631 //------------------------------------------------------------------------------
632 static int CeedVectorScale_Cuda(CeedVector x, CeedScalar alpha) {
633   Ceed ceed;
634   CeedCallBackend(CeedVectorGetCeed(x, &ceed));
635   CeedVector_Cuda *x_impl;
636   CeedCallBackend(CeedVectorGetData(x, &x_impl));
637   CeedSize length;
638   CeedCallBackend(CeedVectorGetLength(x, &length));
639 
640   // Set value for synced device/host array
641   if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Cuda(x_impl->d_array, alpha, length));
642   if (x_impl->h_array) CeedCallBackend(CeedHostScale_Cuda(x_impl->h_array, alpha, length));
643 
644   return CEED_ERROR_SUCCESS;
645 }
646 
647 //------------------------------------------------------------------------------
648 // Compute y = alpha x + y on the host
649 //------------------------------------------------------------------------------
650 static int CeedHostAXPY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) {
651   for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i];
652   return CEED_ERROR_SUCCESS;
653 }
654 
655 //------------------------------------------------------------------------------
656 // Compute y = alpha x + y on device (impl in .cu file)
657 //------------------------------------------------------------------------------
658 int CeedDeviceAXPY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length);
659 
660 //------------------------------------------------------------------------------
661 // Compute y = alpha x + y
662 //------------------------------------------------------------------------------
663 static int CeedVectorAXPY_Cuda(CeedVector y, CeedScalar alpha, CeedVector x) {
664   Ceed ceed;
665   CeedCallBackend(CeedVectorGetCeed(y, &ceed));
666   CeedVector_Cuda *y_impl, *x_impl;
667   CeedCallBackend(CeedVectorGetData(y, &y_impl));
668   CeedCallBackend(CeedVectorGetData(x, &x_impl));
669   CeedSize length;
670   CeedCallBackend(CeedVectorGetLength(y, &length));
671 
672   // Set value for synced device/host array
673   if (y_impl->d_array) {
674     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
675     CeedCallBackend(CeedDeviceAXPY_Cuda(y_impl->d_array, alpha, x_impl->d_array, length));
676   }
677   if (y_impl->h_array) {
678     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
679     CeedCallBackend(CeedHostAXPY_Cuda(y_impl->h_array, alpha, x_impl->h_array, length));
680   }
681 
682   return CEED_ERROR_SUCCESS;
683 }
684 
685 //------------------------------------------------------------------------------
686 // Compute y = alpha x + beta y on the host
687 //------------------------------------------------------------------------------
688 static int CeedHostAXPBY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length) {
689   for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i] + beta * y_array[i];
690   return CEED_ERROR_SUCCESS;
691 }
692 
693 //------------------------------------------------------------------------------
694 // Compute y = alpha x + beta y on device (impl in .cu file)
695 //------------------------------------------------------------------------------
696 int CeedDeviceAXPBY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length);
697 
698 //------------------------------------------------------------------------------
699 // Compute y = alpha x + beta y
700 //------------------------------------------------------------------------------
701 static int CeedVectorAXPBY_Cuda(CeedVector y, CeedScalar alpha, CeedScalar beta, CeedVector x) {
702   Ceed ceed;
703   CeedCallBackend(CeedVectorGetCeed(y, &ceed));
704   CeedVector_Cuda *y_impl, *x_impl;
705   CeedCallBackend(CeedVectorGetData(y, &y_impl));
706   CeedCallBackend(CeedVectorGetData(x, &x_impl));
707   CeedSize length;
708   CeedCallBackend(CeedVectorGetLength(y, &length));
709 
710   // Set value for synced device/host array
711   if (y_impl->d_array) {
712     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
713     CeedCallBackend(CeedDeviceAXPBY_Cuda(y_impl->d_array, alpha, beta, x_impl->d_array, length));
714   }
715   if (y_impl->h_array) {
716     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
717     CeedCallBackend(CeedHostAXPBY_Cuda(y_impl->h_array, alpha, beta, x_impl->h_array, length));
718   }
719 
720   return CEED_ERROR_SUCCESS;
721 }
722 
723 //------------------------------------------------------------------------------
724 // Compute the pointwise multiplication w = x .* y on the host
725 //------------------------------------------------------------------------------
726 static int CeedHostPointwiseMult_Cuda(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) {
727   for (CeedSize i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i];
728   return CEED_ERROR_SUCCESS;
729 }
730 
731 //------------------------------------------------------------------------------
732 // Compute the pointwise multiplication w = x .* y on device (impl in .cu file)
733 //------------------------------------------------------------------------------
734 int CeedDevicePointwiseMult_Cuda(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length);
735 
736 //------------------------------------------------------------------------------
737 // Compute the pointwise multiplication w = x .* y
738 //------------------------------------------------------------------------------
739 static int CeedVectorPointwiseMult_Cuda(CeedVector w, CeedVector x, CeedVector y) {
740   Ceed ceed;
741   CeedCallBackend(CeedVectorGetCeed(w, &ceed));
742   CeedVector_Cuda *w_impl, *x_impl, *y_impl;
743   CeedCallBackend(CeedVectorGetData(w, &w_impl));
744   CeedCallBackend(CeedVectorGetData(x, &x_impl));
745   CeedCallBackend(CeedVectorGetData(y, &y_impl));
746   CeedSize length;
747   CeedCallBackend(CeedVectorGetLength(w, &length));
748 
749   // Set value for synced device/host array
750   if (!w_impl->d_array && !w_impl->h_array) {
751     CeedCallBackend(CeedVectorSetValue(w, 0.0));
752   }
753   if (w_impl->d_array) {
754     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
755     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE));
756     CeedCallBackend(CeedDevicePointwiseMult_Cuda(w_impl->d_array, x_impl->d_array, y_impl->d_array, length));
757   }
758   if (w_impl->h_array) {
759     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
760     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST));
761     CeedCallBackend(CeedHostPointwiseMult_Cuda(w_impl->h_array, x_impl->h_array, y_impl->h_array, length));
762   }
763 
764   return CEED_ERROR_SUCCESS;
765 }
766 
767 //------------------------------------------------------------------------------
768 // Destroy the vector
769 //------------------------------------------------------------------------------
770 static int CeedVectorDestroy_Cuda(const CeedVector vec) {
771   Ceed ceed;
772   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
773   CeedVector_Cuda *impl;
774   CeedCallBackend(CeedVectorGetData(vec, &impl));
775 
776   CeedCallCuda(ceed, cudaFree(impl->d_array_owned));
777   CeedCallBackend(CeedFree(&impl->h_array_owned));
778   CeedCallBackend(CeedFree(&impl));
779 
780   return CEED_ERROR_SUCCESS;
781 }
782 
783 //------------------------------------------------------------------------------
784 // Create a vector of the specified length (does not allocate memory)
785 //------------------------------------------------------------------------------
786 int CeedVectorCreate_Cuda(CeedSize n, CeedVector vec) {
787   CeedVector_Cuda *impl;
788   Ceed             ceed;
789   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
790 
791   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Cuda));
792   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Cuda));
793   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Cuda));
794   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Cuda));
795   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", (int (*)())CeedVectorSetValue_Cuda));
796   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Cuda));
797   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Cuda));
798   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Cuda));
799   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Cuda));
800   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Norm", CeedVectorNorm_Cuda));
801   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Cuda));
802   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Scale", (int (*)())CeedVectorScale_Cuda));
803   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPY", (int (*)())CeedVectorAXPY_Cuda));
804   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPBY", (int (*)())CeedVectorAXPBY_Cuda));
805   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Cuda));
806   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Cuda));
807 
808   CeedCallBackend(CeedCalloc(1, &impl));
809   CeedCallBackend(CeedVectorSetData(vec, impl));
810 
811   return CEED_ERROR_SUCCESS;
812 }
813 
814 //------------------------------------------------------------------------------
815