xref: /libCEED/backends/cuda-ref/ceed-cuda-ref-vector.c (revision edb2538e3dd6743c029967fc4e89c6fcafedb8c2)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <cuda_runtime.h>
11 #include <math.h>
12 #include <stdbool.h>
13 #include <string.h>
14 
15 #include "../cuda/ceed-cuda-common.h"
16 #include "ceed-cuda-ref.h"
17 
18 //------------------------------------------------------------------------------
19 // Check if host/device sync is needed
20 //------------------------------------------------------------------------------
21 static inline int CeedVectorNeedSync_Cuda(const CeedVector vec, CeedMemType mem_type, bool *need_sync) {
22   bool             has_valid_array = false;
23   CeedVector_Cuda *impl;
24 
25   CeedCallBackend(CeedVectorGetData(vec, &impl));
26   CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array));
27   switch (mem_type) {
28     case CEED_MEM_HOST:
29       *need_sync = has_valid_array && !impl->h_array;
30       break;
31     case CEED_MEM_DEVICE:
32       *need_sync = has_valid_array && !impl->d_array;
33       break;
34   }
35   return CEED_ERROR_SUCCESS;
36 }
37 
38 //------------------------------------------------------------------------------
39 // Sync host to device
40 //------------------------------------------------------------------------------
41 static inline int CeedVectorSyncH2D_Cuda(const CeedVector vec) {
42   Ceed             ceed;
43   CeedSize         length;
44   CeedVector_Cuda *impl;
45 
46   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
47   CeedCallBackend(CeedVectorGetData(vec, &impl));
48 
49   CeedCheck(impl->h_array, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
50 
51   CeedCallBackend(CeedVectorGetLength(vec, &length));
52   size_t bytes = length * sizeof(CeedScalar);
53 
54   if (impl->d_array_borrowed) {
55     impl->d_array = impl->d_array_borrowed;
56   } else if (impl->d_array_owned) {
57     impl->d_array = impl->d_array_owned;
58   } else {
59     CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_array_owned, bytes));
60     impl->d_array = impl->d_array_owned;
61   }
62   CeedCallCuda(ceed, cudaMemcpy(impl->d_array, impl->h_array, bytes, cudaMemcpyHostToDevice));
63   return CEED_ERROR_SUCCESS;
64 }
65 
66 //------------------------------------------------------------------------------
67 // Sync device to host
68 //------------------------------------------------------------------------------
69 static inline int CeedVectorSyncD2H_Cuda(const CeedVector vec) {
70   Ceed             ceed;
71   CeedSize         length;
72   CeedVector_Cuda *impl;
73 
74   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
75   CeedCallBackend(CeedVectorGetData(vec, &impl));
76 
77   CeedCheck(impl->d_array, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
78 
79   if (impl->h_array_borrowed) {
80     impl->h_array = impl->h_array_borrowed;
81   } else if (impl->h_array_owned) {
82     impl->h_array = impl->h_array_owned;
83   } else {
84     CeedSize length;
85 
86     CeedCallBackend(CeedVectorGetLength(vec, &length));
87     CeedCallBackend(CeedCalloc(length, &impl->h_array_owned));
88     impl->h_array = impl->h_array_owned;
89   }
90 
91   CeedCallBackend(CeedVectorGetLength(vec, &length));
92   size_t bytes = length * sizeof(CeedScalar);
93 
94   CeedCallCuda(ceed, cudaMemcpy(impl->h_array, impl->d_array, bytes, cudaMemcpyDeviceToHost));
95   return CEED_ERROR_SUCCESS;
96 }
97 
98 //------------------------------------------------------------------------------
99 // Sync arrays
100 //------------------------------------------------------------------------------
101 static int CeedVectorSyncArray_Cuda(const CeedVector vec, CeedMemType mem_type) {
102   bool need_sync = false;
103 
104   // Check whether device/host sync is needed
105   CeedCallBackend(CeedVectorNeedSync_Cuda(vec, mem_type, &need_sync));
106   if (!need_sync) return CEED_ERROR_SUCCESS;
107 
108   switch (mem_type) {
109     case CEED_MEM_HOST:
110       return CeedVectorSyncD2H_Cuda(vec);
111     case CEED_MEM_DEVICE:
112       return CeedVectorSyncH2D_Cuda(vec);
113   }
114   return CEED_ERROR_UNSUPPORTED;
115 }
116 
117 //------------------------------------------------------------------------------
118 // Set all pointers as invalid
119 //------------------------------------------------------------------------------
120 static inline int CeedVectorSetAllInvalid_Cuda(const CeedVector vec) {
121   CeedVector_Cuda *impl;
122 
123   CeedCallBackend(CeedVectorGetData(vec, &impl));
124   impl->h_array = NULL;
125   impl->d_array = NULL;
126   return CEED_ERROR_SUCCESS;
127 }
128 
129 //------------------------------------------------------------------------------
130 // Check if CeedVector has any valid pointer
131 //------------------------------------------------------------------------------
132 static inline int CeedVectorHasValidArray_Cuda(const CeedVector vec, bool *has_valid_array) {
133   CeedVector_Cuda *impl;
134 
135   CeedCallBackend(CeedVectorGetData(vec, &impl));
136   *has_valid_array = impl->h_array || impl->d_array;
137   return CEED_ERROR_SUCCESS;
138 }
139 
140 //------------------------------------------------------------------------------
141 // Check if has array of given type
142 //------------------------------------------------------------------------------
143 static inline int CeedVectorHasArrayOfType_Cuda(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) {
144   CeedVector_Cuda *impl;
145 
146   CeedCallBackend(CeedVectorGetData(vec, &impl));
147   switch (mem_type) {
148     case CEED_MEM_HOST:
149       *has_array_of_type = impl->h_array_borrowed || impl->h_array_owned;
150       break;
151     case CEED_MEM_DEVICE:
152       *has_array_of_type = impl->d_array_borrowed || impl->d_array_owned;
153       break;
154   }
155   return CEED_ERROR_SUCCESS;
156 }
157 
158 //------------------------------------------------------------------------------
159 // Check if has borrowed array of given type
160 //------------------------------------------------------------------------------
161 static inline int CeedVectorHasBorrowedArrayOfType_Cuda(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) {
162   CeedVector_Cuda *impl;
163 
164   CeedCallBackend(CeedVectorGetData(vec, &impl));
165   switch (mem_type) {
166     case CEED_MEM_HOST:
167       *has_borrowed_array_of_type = impl->h_array_borrowed;
168       break;
169     case CEED_MEM_DEVICE:
170       *has_borrowed_array_of_type = impl->d_array_borrowed;
171       break;
172   }
173   return CEED_ERROR_SUCCESS;
174 }
175 
176 //------------------------------------------------------------------------------
177 // Set array from host
178 //------------------------------------------------------------------------------
179 static int CeedVectorSetArrayHost_Cuda(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
180   CeedVector_Cuda *impl;
181 
182   CeedCallBackend(CeedVectorGetData(vec, &impl));
183   switch (copy_mode) {
184     case CEED_COPY_VALUES: {
185       if (!impl->h_array_owned) {
186         CeedSize length;
187 
188         CeedCallBackend(CeedVectorGetLength(vec, &length));
189         CeedCallBackend(CeedMalloc(length, &impl->h_array_owned));
190       }
191       impl->h_array_borrowed = NULL;
192       impl->h_array          = impl->h_array_owned;
193       if (array) {
194         CeedSize length;
195 
196         CeedCallBackend(CeedVectorGetLength(vec, &length));
197         size_t bytes = length * sizeof(CeedScalar);
198         memcpy(impl->h_array, array, bytes);
199       }
200     } break;
201     case CEED_OWN_POINTER:
202       CeedCallBackend(CeedFree(&impl->h_array_owned));
203       impl->h_array_owned    = array;
204       impl->h_array_borrowed = NULL;
205       impl->h_array          = array;
206       break;
207     case CEED_USE_POINTER:
208       CeedCallBackend(CeedFree(&impl->h_array_owned));
209       impl->h_array_borrowed = array;
210       impl->h_array          = array;
211       break;
212   }
213   return CEED_ERROR_SUCCESS;
214 }
215 
216 //------------------------------------------------------------------------------
217 // Set array from device
218 //------------------------------------------------------------------------------
219 static int CeedVectorSetArrayDevice_Cuda(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
220   Ceed             ceed;
221   CeedVector_Cuda *impl;
222 
223   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
224   CeedCallBackend(CeedVectorGetData(vec, &impl));
225   switch (copy_mode) {
226     case CEED_COPY_VALUES: {
227       CeedSize length;
228 
229       CeedCallBackend(CeedVectorGetLength(vec, &length));
230       size_t bytes = length * sizeof(CeedScalar);
231 
232       if (!impl->d_array_owned) {
233         CeedCallCuda(ceed, cudaMalloc((void **)&impl->d_array_owned, bytes));
234       }
235       impl->d_array_borrowed = NULL;
236       impl->d_array          = impl->d_array_owned;
237       if (array) CeedCallCuda(ceed, cudaMemcpy(impl->d_array, array, bytes, cudaMemcpyDeviceToDevice));
238     } break;
239     case CEED_OWN_POINTER:
240       CeedCallCuda(ceed, cudaFree(impl->d_array_owned));
241       impl->d_array_owned    = array;
242       impl->d_array_borrowed = NULL;
243       impl->d_array          = array;
244       break;
245     case CEED_USE_POINTER:
246       CeedCallCuda(ceed, cudaFree(impl->d_array_owned));
247       impl->d_array_owned    = NULL;
248       impl->d_array_borrowed = array;
249       impl->d_array          = array;
250       break;
251   }
252   return CEED_ERROR_SUCCESS;
253 }
254 
255 //------------------------------------------------------------------------------
256 // Set the array used by a vector,
257 //   freeing any previously allocated array if applicable
258 //------------------------------------------------------------------------------
259 static int CeedVectorSetArray_Cuda(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) {
260   Ceed             ceed;
261   CeedVector_Cuda *impl;
262 
263   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
264   CeedCallBackend(CeedVectorGetData(vec, &impl));
265   CeedCallBackend(CeedVectorSetAllInvalid_Cuda(vec));
266   switch (mem_type) {
267     case CEED_MEM_HOST:
268       return CeedVectorSetArrayHost_Cuda(vec, copy_mode, array);
269     case CEED_MEM_DEVICE:
270       return CeedVectorSetArrayDevice_Cuda(vec, copy_mode, array);
271   }
272   return CEED_ERROR_UNSUPPORTED;
273 }
274 
275 //------------------------------------------------------------------------------
276 // Set host array to value
277 //------------------------------------------------------------------------------
278 static int CeedHostSetValue_Cuda(CeedScalar *h_array, CeedSize length, CeedScalar val) {
279   for (CeedSize i = 0; i < length; i++) h_array[i] = val;
280   return CEED_ERROR_SUCCESS;
281 }
282 
283 //------------------------------------------------------------------------------
284 // Set device array to value (impl in .cu file)
285 //------------------------------------------------------------------------------
286 int CeedDeviceSetValue_Cuda(CeedScalar *d_array, CeedSize length, CeedScalar val);
287 
288 //------------------------------------------------------------------------------
289 // Set a vector to a value
290 //------------------------------------------------------------------------------
291 static int CeedVectorSetValue_Cuda(CeedVector vec, CeedScalar val) {
292   Ceed             ceed;
293   CeedSize         length;
294   CeedVector_Cuda *impl;
295 
296   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
297   CeedCallBackend(CeedVectorGetData(vec, &impl));
298   CeedCallBackend(CeedVectorGetLength(vec, &length));
299   // Set value for synced device/host array
300   if (!impl->d_array && !impl->h_array) {
301     if (impl->d_array_borrowed) {
302       impl->d_array = impl->d_array_borrowed;
303     } else if (impl->h_array_borrowed) {
304       impl->h_array = impl->h_array_borrowed;
305     } else if (impl->d_array_owned) {
306       impl->d_array = impl->d_array_owned;
307     } else if (impl->h_array_owned) {
308       impl->h_array = impl->h_array_owned;
309     } else {
310       CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL));
311     }
312   }
313   if (impl->d_array) {
314     CeedCallBackend(CeedDeviceSetValue_Cuda(impl->d_array, length, val));
315     impl->h_array = NULL;
316   }
317   if (impl->h_array) {
318     CeedCallBackend(CeedHostSetValue_Cuda(impl->h_array, length, val));
319     impl->d_array = NULL;
320   }
321   return CEED_ERROR_SUCCESS;
322 }
323 
324 //------------------------------------------------------------------------------
325 // Vector Take Array
326 //------------------------------------------------------------------------------
327 static int CeedVectorTakeArray_Cuda(CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
328   Ceed             ceed;
329   CeedVector_Cuda *impl;
330 
331   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
332   CeedCallBackend(CeedVectorGetData(vec, &impl));
333   // Sync array to requested mem_type
334   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
335   // Update pointer
336   switch (mem_type) {
337     case CEED_MEM_HOST:
338       (*array)               = impl->h_array_borrowed;
339       impl->h_array_borrowed = NULL;
340       impl->h_array          = NULL;
341       break;
342     case CEED_MEM_DEVICE:
343       (*array)               = impl->d_array_borrowed;
344       impl->d_array_borrowed = NULL;
345       impl->d_array          = NULL;
346       break;
347   }
348   return CEED_ERROR_SUCCESS;
349 }
350 
351 //------------------------------------------------------------------------------
352 // Core logic for array syncronization for GetArray.
353 //   If a different memory type is most up to date, this will perform a copy
354 //------------------------------------------------------------------------------
355 static int CeedVectorGetArrayCore_Cuda(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
356   Ceed             ceed;
357   CeedVector_Cuda *impl;
358 
359   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
360   CeedCallBackend(CeedVectorGetData(vec, &impl));
361   // Sync array to requested mem_type
362   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
363   // Update pointer
364   switch (mem_type) {
365     case CEED_MEM_HOST:
366       *array = impl->h_array;
367       break;
368     case CEED_MEM_DEVICE:
369       *array = impl->d_array;
370       break;
371   }
372   return CEED_ERROR_SUCCESS;
373 }
374 
375 //------------------------------------------------------------------------------
376 // Get read-only access to a vector via the specified mem_type
377 //------------------------------------------------------------------------------
378 static int CeedVectorGetArrayRead_Cuda(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) {
379   return CeedVectorGetArrayCore_Cuda(vec, mem_type, (CeedScalar **)array);
380 }
381 
382 //------------------------------------------------------------------------------
383 // Get read/write access to a vector via the specified mem_type
384 //------------------------------------------------------------------------------
385 static int CeedVectorGetArray_Cuda(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
386   CeedVector_Cuda *impl;
387 
388   CeedCallBackend(CeedVectorGetData(vec, &impl));
389   CeedCallBackend(CeedVectorGetArrayCore_Cuda(vec, mem_type, array));
390   CeedCallBackend(CeedVectorSetAllInvalid_Cuda(vec));
391   switch (mem_type) {
392     case CEED_MEM_HOST:
393       impl->h_array = *array;
394       break;
395     case CEED_MEM_DEVICE:
396       impl->d_array = *array;
397       break;
398   }
399   return CEED_ERROR_SUCCESS;
400 }
401 
402 //------------------------------------------------------------------------------
403 // Get write access to a vector via the specified mem_type
404 //------------------------------------------------------------------------------
405 static int CeedVectorGetArrayWrite_Cuda(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
406   bool             has_array_of_type = true;
407   CeedVector_Cuda *impl;
408 
409   CeedCallBackend(CeedVectorGetData(vec, &impl));
410   CeedCallBackend(CeedVectorHasArrayOfType_Cuda(vec, mem_type, &has_array_of_type));
411   if (!has_array_of_type) {
412     // Allocate if array is not yet allocated
413     CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL));
414   } else {
415     // Select dirty array
416     switch (mem_type) {
417       case CEED_MEM_HOST:
418         if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed;
419         else impl->h_array = impl->h_array_owned;
420         break;
421       case CEED_MEM_DEVICE:
422         if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed;
423         else impl->d_array = impl->d_array_owned;
424     }
425   }
426   return CeedVectorGetArray_Cuda(vec, mem_type, array);
427 }
428 
429 //------------------------------------------------------------------------------
430 // Get the norm of a CeedVector
431 //------------------------------------------------------------------------------
432 static int CeedVectorNorm_Cuda(CeedVector vec, CeedNormType type, CeedScalar *norm) {
433   Ceed              ceed;
434   cublasHandle_t    handle;
435   CeedSize          length;
436   const CeedScalar *d_array;
437   CeedVector_Cuda  *impl;
438 
439   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
440   CeedCallBackend(CeedVectorGetData(vec, &impl));
441   CeedCallBackend(CeedVectorGetLength(vec, &length));
442   CeedCallBackend(CeedGetCublasHandle_Cuda(ceed, &handle));
443 
444 #if CUDA_VERSION < 12000
445   // With CUDA 12, we can use the 64-bit integer interface. Prior to that,
446   // we need to check if the vector is too long to handle with int32,
447   // and if so, divide it into subsections for repeated cuBLAS calls.
448   CeedSize num_calls = length / INT_MAX;
449 
450   if (length % INT_MAX > 0) num_calls += 1;
451 #endif
452 
453   // Compute norm
454   CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array));
455   switch (type) {
456     case CEED_NORM_1: {
457       *norm = 0.0;
458       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
459 #if CUDA_VERSION >= 12000  // We have CUDA 12, and can use 64-bit integers
460         CeedCallCublas(ceed, cublasSasum_64(handle, (int64_t)length, (float *)d_array, 1, (float *)norm));
461 #else
462         float  sub_norm = 0.0;
463         float *d_array_start;
464 
465         for (CeedInt i = 0; i < num_calls; i++) {
466           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
467           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
468           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
469 
470           CeedCallCublas(ceed, cublasSasum(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
471           *norm += sub_norm;
472         }
473 #endif
474       } else {
475 #if CUDA_VERSION >= 12000
476         CeedCallCublas(ceed, cublasDasum_64(handle, (int64_t)length, (double *)d_array, 1, (double *)norm));
477 #else
478         double  sub_norm = 0.0;
479         double *d_array_start;
480 
481         for (CeedInt i = 0; i < num_calls; i++) {
482           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
483           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
484           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
485 
486           CeedCallCublas(ceed, cublasDasum(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
487           *norm += sub_norm;
488         }
489 #endif
490       }
491       break;
492     }
493     case CEED_NORM_2: {
494       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
495 #if CUDA_VERSION >= 12000
496         CeedCallCublas(ceed, cublasSnrm2_64(handle, (int64_t)length, (float *)d_array, 1, (float *)norm));
497 #else
498         float  sub_norm = 0.0, norm_sum = 0.0;
499         float *d_array_start;
500 
501         for (CeedInt i = 0; i < num_calls; i++) {
502           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
503           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
504           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
505 
506           CeedCallCublas(ceed, cublasSnrm2(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
507           norm_sum += sub_norm * sub_norm;
508         }
509         *norm            = sqrt(norm_sum);
510 #endif
511       } else {
512 #if CUDA_VERSION >= 12000
513         CeedCallCublas(ceed, cublasDnrm2_64(handle, (int64_t)length, (double *)d_array, 1, (double *)norm));
514 #else
515         double  sub_norm = 0.0, norm_sum = 0.0;
516         double *d_array_start;
517 
518         for (CeedInt i = 0; i < num_calls; i++) {
519           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
520           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
521           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
522 
523           CeedCallCublas(ceed, cublasDnrm2(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
524           norm_sum += sub_norm * sub_norm;
525         }
526         *norm = sqrt(norm_sum);
527 #endif
528       }
529       break;
530     }
531     case CEED_NORM_MAX: {
532       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
533 #if CUDA_VERSION >= 12000
534         int64_t    index;
535         CeedScalar norm_no_abs;
536 
537         CeedCallCublas(ceed, cublasIsamax_64(handle, (int64_t)length, (float *)d_array, 1, &index));
538         CeedCallCuda(ceed, cudaMemcpy(&norm_no_abs, impl->d_array + index - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
539         *norm = fabs(norm_no_abs);
540 #else
541         CeedInt index;
542         float   sub_max = 0.0, current_max = 0.0;
543         float  *d_array_start;
544 
545         for (CeedInt i = 0; i < num_calls; i++) {
546           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
547           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
548           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
549 
550           CeedCallCublas(ceed, cublasIsamax(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &index));
551           CeedCallCuda(ceed, cudaMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
552           if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
553         }
554         *norm = current_max;
555 #endif
556       } else {
557 #if CUDA_VERSION >= 12000
558         int64_t    index;
559         CeedScalar norm_no_abs;
560 
561         CeedCallCublas(ceed, cublasIdamax_64(handle, (int64_t)length, (double *)d_array, 1, &index));
562         CeedCallCuda(ceed, cudaMemcpy(&norm_no_abs, impl->d_array + index - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
563         *norm = fabs(norm_no_abs);
564 #else
565         CeedInt index;
566         double  sub_max = 0.0, current_max = 0.0;
567         double *d_array_start;
568 
569         for (CeedInt i = 0; i < num_calls; i++) {
570           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
571           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
572           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
573 
574           CeedCallCublas(ceed, cublasIdamax(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &index));
575           CeedCallCuda(ceed, cudaMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
576           if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
577         }
578         *norm = current_max;
579 #endif
580       }
581       break;
582     }
583   }
584   CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array));
585   return CEED_ERROR_SUCCESS;
586 }
587 
588 //------------------------------------------------------------------------------
589 // Take reciprocal of a vector on host
590 //------------------------------------------------------------------------------
591 static int CeedHostReciprocal_Cuda(CeedScalar *h_array, CeedSize length) {
592   for (CeedSize i = 0; i < length; i++) {
593     if (fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i];
594   }
595   return CEED_ERROR_SUCCESS;
596 }
597 
598 //------------------------------------------------------------------------------
599 // Take reciprocal of a vector on device (impl in .cu file)
600 //------------------------------------------------------------------------------
601 int CeedDeviceReciprocal_Cuda(CeedScalar *d_array, CeedSize length);
602 
603 //------------------------------------------------------------------------------
604 // Take reciprocal of a vector
605 //------------------------------------------------------------------------------
606 static int CeedVectorReciprocal_Cuda(CeedVector vec) {
607   Ceed             ceed;
608   CeedSize         length;
609   CeedVector_Cuda *impl;
610 
611   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
612   CeedCallBackend(CeedVectorGetData(vec, &impl));
613   CeedCallBackend(CeedVectorGetLength(vec, &length));
614   // Set value for synced device/host array
615   if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Cuda(impl->d_array, length));
616   if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Cuda(impl->h_array, length));
617   return CEED_ERROR_SUCCESS;
618 }
619 
620 //------------------------------------------------------------------------------
621 // Compute x = alpha x on the host
622 //------------------------------------------------------------------------------
623 static int CeedHostScale_Cuda(CeedScalar *x_array, CeedScalar alpha, CeedSize length) {
624   for (CeedSize i = 0; i < length; i++) x_array[i] *= alpha;
625   return CEED_ERROR_SUCCESS;
626 }
627 
628 //------------------------------------------------------------------------------
629 // Compute x = alpha x on device (impl in .cu file)
630 //------------------------------------------------------------------------------
631 int CeedDeviceScale_Cuda(CeedScalar *x_array, CeedScalar alpha, CeedSize length);
632 
633 //------------------------------------------------------------------------------
634 // Compute x = alpha x
635 //------------------------------------------------------------------------------
636 static int CeedVectorScale_Cuda(CeedVector x, CeedScalar alpha) {
637   Ceed             ceed;
638   CeedSize         length;
639   CeedVector_Cuda *x_impl;
640 
641   CeedCallBackend(CeedVectorGetCeed(x, &ceed));
642   CeedCallBackend(CeedVectorGetData(x, &x_impl));
643   CeedCallBackend(CeedVectorGetLength(x, &length));
644   // Set value for synced device/host array
645   if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Cuda(x_impl->d_array, alpha, length));
646   if (x_impl->h_array) CeedCallBackend(CeedHostScale_Cuda(x_impl->h_array, alpha, length));
647   return CEED_ERROR_SUCCESS;
648 }
649 
650 //------------------------------------------------------------------------------
651 // Compute y = alpha x + y on the host
652 //------------------------------------------------------------------------------
653 static int CeedHostAXPY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) {
654   for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i];
655   return CEED_ERROR_SUCCESS;
656 }
657 
658 //------------------------------------------------------------------------------
659 // Compute y = alpha x + y on device (impl in .cu file)
660 //------------------------------------------------------------------------------
661 int CeedDeviceAXPY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length);
662 
663 //------------------------------------------------------------------------------
664 // Compute y = alpha x + y
665 //------------------------------------------------------------------------------
666 static int CeedVectorAXPY_Cuda(CeedVector y, CeedScalar alpha, CeedVector x) {
667   Ceed             ceed;
668   CeedSize         length;
669   CeedVector_Cuda *y_impl, *x_impl;
670 
671   CeedCallBackend(CeedVectorGetCeed(y, &ceed));
672   CeedCallBackend(CeedVectorGetData(y, &y_impl));
673   CeedCallBackend(CeedVectorGetData(x, &x_impl));
674   CeedCallBackend(CeedVectorGetLength(y, &length));
675   // Set value for synced device/host array
676   if (y_impl->d_array) {
677     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
678     CeedCallBackend(CeedDeviceAXPY_Cuda(y_impl->d_array, alpha, x_impl->d_array, length));
679   }
680   if (y_impl->h_array) {
681     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
682     CeedCallBackend(CeedHostAXPY_Cuda(y_impl->h_array, alpha, x_impl->h_array, length));
683   }
684   return CEED_ERROR_SUCCESS;
685 }
686 
687 //------------------------------------------------------------------------------
688 // Compute y = alpha x + beta y on the host
689 //------------------------------------------------------------------------------
690 static int CeedHostAXPBY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length) {
691   for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i] + beta * y_array[i];
692   return CEED_ERROR_SUCCESS;
693 }
694 
695 //------------------------------------------------------------------------------
696 // Compute y = alpha x + beta y on device (impl in .cu file)
697 //------------------------------------------------------------------------------
698 int CeedDeviceAXPBY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length);
699 
700 //------------------------------------------------------------------------------
701 // Compute y = alpha x + beta y
702 //------------------------------------------------------------------------------
703 static int CeedVectorAXPBY_Cuda(CeedVector y, CeedScalar alpha, CeedScalar beta, CeedVector x) {
704   Ceed             ceed;
705   CeedSize         length;
706   CeedVector_Cuda *y_impl, *x_impl;
707 
708   CeedCallBackend(CeedVectorGetCeed(y, &ceed));
709   CeedCallBackend(CeedVectorGetData(y, &y_impl));
710   CeedCallBackend(CeedVectorGetData(x, &x_impl));
711   CeedCallBackend(CeedVectorGetLength(y, &length));
712   // Set value for synced device/host array
713   if (y_impl->d_array) {
714     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
715     CeedCallBackend(CeedDeviceAXPBY_Cuda(y_impl->d_array, alpha, beta, x_impl->d_array, length));
716   }
717   if (y_impl->h_array) {
718     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
719     CeedCallBackend(CeedHostAXPBY_Cuda(y_impl->h_array, alpha, beta, x_impl->h_array, length));
720   }
721   return CEED_ERROR_SUCCESS;
722 }
723 
724 //------------------------------------------------------------------------------
725 // Compute the pointwise multiplication w = x .* y on the host
726 //------------------------------------------------------------------------------
727 static int CeedHostPointwiseMult_Cuda(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) {
728   for (CeedSize i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i];
729   return CEED_ERROR_SUCCESS;
730 }
731 
//------------------------------------------------------------------------------
// Compute the pointwise multiplication w = x .* y on device (impl in .cu file)
//
// Declaration only; the definition lives in a separately compiled CUDA source.
// CeedVectorPointwiseMult_Cuda calls it with the d_array pointers of w, x and y
// and the length of w. Returns a libCEED error code.
//------------------------------------------------------------------------------
int CeedDevicePointwiseMult_Cuda(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length);
736 
737 //------------------------------------------------------------------------------
738 // Compute the pointwise multiplication w = x .* y
739 //------------------------------------------------------------------------------
740 static int CeedVectorPointwiseMult_Cuda(CeedVector w, CeedVector x, CeedVector y) {
741   Ceed             ceed;
742   CeedSize         length;
743   CeedVector_Cuda *w_impl, *x_impl, *y_impl;
744 
745   CeedCallBackend(CeedVectorGetCeed(w, &ceed));
746   CeedCallBackend(CeedVectorGetData(w, &w_impl));
747   CeedCallBackend(CeedVectorGetData(x, &x_impl));
748   CeedCallBackend(CeedVectorGetData(y, &y_impl));
749   CeedCallBackend(CeedVectorGetLength(w, &length));
750   // Set value for synced device/host array
751   if (!w_impl->d_array && !w_impl->h_array) {
752     CeedCallBackend(CeedVectorSetValue(w, 0.0));
753   }
754   if (w_impl->d_array) {
755     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
756     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE));
757     CeedCallBackend(CeedDevicePointwiseMult_Cuda(w_impl->d_array, x_impl->d_array, y_impl->d_array, length));
758   }
759   if (w_impl->h_array) {
760     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
761     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST));
762     CeedCallBackend(CeedHostPointwiseMult_Cuda(w_impl->h_array, x_impl->h_array, y_impl->h_array, length));
763   }
764   return CEED_ERROR_SUCCESS;
765 }
766 
767 //------------------------------------------------------------------------------
768 // Destroy the vector
769 //------------------------------------------------------------------------------
770 static int CeedVectorDestroy_Cuda(const CeedVector vec) {
771   Ceed             ceed;
772   CeedVector_Cuda *impl;
773 
774   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
775   CeedCallBackend(CeedVectorGetData(vec, &impl));
776   CeedCallCuda(ceed, cudaFree(impl->d_array_owned));
777   CeedCallBackend(CeedFree(&impl->h_array_owned));
778   CeedCallBackend(CeedFree(&impl));
779   return CEED_ERROR_SUCCESS;
780 }
781 
782 //------------------------------------------------------------------------------
783 // Create a vector of the specified length (does not allocate memory)
784 //------------------------------------------------------------------------------
785 int CeedVectorCreate_Cuda(CeedSize n, CeedVector vec) {
786   CeedVector_Cuda *impl;
787   Ceed             ceed;
788 
789   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
790   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Cuda));
791   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Cuda));
792   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Cuda));
793   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Cuda));
794   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", (int (*)())CeedVectorSetValue_Cuda));
795   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Cuda));
796   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Cuda));
797   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Cuda));
798   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Cuda));
799   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Norm", CeedVectorNorm_Cuda));
800   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Cuda));
801   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Scale", (int (*)())CeedVectorScale_Cuda));
802   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPY", (int (*)())CeedVectorAXPY_Cuda));
803   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPBY", (int (*)())CeedVectorAXPBY_Cuda));
804   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Cuda));
805   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Cuda));
806   CeedCallBackend(CeedCalloc(1, &impl));
807   CeedCallBackend(CeedVectorSetData(vec, impl));
808   return CEED_ERROR_SUCCESS;
809 }
810 
811 //------------------------------------------------------------------------------
812