xref: /libCEED/backends/cuda-ref/ceed-cuda-ref-vector.c (revision cc3bdf8cd7a97c52371610b7fbb458a86c2b0cc9)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <cuda_runtime.h>
11 #include <math.h>
12 #include <stdbool.h>
13 #include <string.h>
14 
15 #include "../cuda/ceed-cuda-common.h"
16 #include "ceed-cuda-ref.h"
17 
18 //------------------------------------------------------------------------------
19 // Check if host/device sync is needed
20 //------------------------------------------------------------------------------
21 static inline int CeedVectorNeedSync_Cuda(const CeedVector vec, CeedMemType mem_type, bool *need_sync) {
22   bool             has_valid_array = false;
23   CeedVector_Cuda *impl;
24 
25   CeedCallBackend(CeedVectorGetData(vec, &impl));
26   CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array));
27   switch (mem_type) {
28     case CEED_MEM_HOST:
29       *need_sync = has_valid_array && !impl->h_array;
30       break;
31     case CEED_MEM_DEVICE:
32       *need_sync = has_valid_array && !impl->d_array;
33       break;
34   }
35   return CEED_ERROR_SUCCESS;
36 }
37 
38 //------------------------------------------------------------------------------
39 // Sync host to device
40 //------------------------------------------------------------------------------
41 static inline int CeedVectorSyncH2D_Cuda(const CeedVector vec) {
42   CeedSize         length;
43   size_t           bytes;
44   CeedVector_Cuda *impl;
45 
46   CeedCallBackend(CeedVectorGetData(vec, &impl));
47 
48   CeedCheck(impl->h_array, CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND, "No valid host data to sync to device");
49 
50   CeedCallBackend(CeedVectorGetLength(vec, &length));
51   bytes = length * sizeof(CeedScalar);
52   if (impl->d_array_borrowed) {
53     impl->d_array = impl->d_array_borrowed;
54   } else if (impl->d_array_owned) {
55     impl->d_array = impl->d_array_owned;
56   } else {
57     CeedCallCuda(CeedVectorReturnCeed(vec), cudaMalloc((void **)&impl->d_array_owned, bytes));
58     impl->d_array = impl->d_array_owned;
59   }
60   CeedCallCuda(CeedVectorReturnCeed(vec), cudaMemcpy(impl->d_array, impl->h_array, bytes, cudaMemcpyHostToDevice));
61   return CEED_ERROR_SUCCESS;
62 }
63 
64 //------------------------------------------------------------------------------
65 // Sync device to host
66 //------------------------------------------------------------------------------
67 static inline int CeedVectorSyncD2H_Cuda(const CeedVector vec) {
68   CeedSize         length;
69   CeedVector_Cuda *impl;
70 
71   CeedCallBackend(CeedVectorGetData(vec, &impl));
72 
73   CeedCheck(impl->d_array, CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND, "No valid device data to sync to host");
74 
75   if (impl->h_array_borrowed) {
76     impl->h_array = impl->h_array_borrowed;
77   } else if (impl->h_array_owned) {
78     impl->h_array = impl->h_array_owned;
79   } else {
80     CeedSize length;
81 
82     CeedCallBackend(CeedVectorGetLength(vec, &length));
83     CeedCallBackend(CeedCalloc(length, &impl->h_array_owned));
84     impl->h_array = impl->h_array_owned;
85   }
86 
87   CeedCallBackend(CeedVectorGetLength(vec, &length));
88   size_t bytes = length * sizeof(CeedScalar);
89 
90   CeedCallCuda(CeedVectorReturnCeed(vec), cudaMemcpy(impl->h_array, impl->d_array, bytes, cudaMemcpyDeviceToHost));
91   return CEED_ERROR_SUCCESS;
92 }
93 
94 //------------------------------------------------------------------------------
95 // Sync arrays
96 //------------------------------------------------------------------------------
97 static int CeedVectorSyncArray_Cuda(const CeedVector vec, CeedMemType mem_type) {
98   bool need_sync = false;
99 
100   // Check whether device/host sync is needed
101   CeedCallBackend(CeedVectorNeedSync_Cuda(vec, mem_type, &need_sync));
102   if (!need_sync) return CEED_ERROR_SUCCESS;
103 
104   switch (mem_type) {
105     case CEED_MEM_HOST:
106       return CeedVectorSyncD2H_Cuda(vec);
107     case CEED_MEM_DEVICE:
108       return CeedVectorSyncH2D_Cuda(vec);
109   }
110   return CEED_ERROR_UNSUPPORTED;
111 }
112 
113 //------------------------------------------------------------------------------
114 // Set all pointers as invalid
115 //------------------------------------------------------------------------------
116 static inline int CeedVectorSetAllInvalid_Cuda(const CeedVector vec) {
117   CeedVector_Cuda *impl;
118 
119   CeedCallBackend(CeedVectorGetData(vec, &impl));
120   impl->h_array = NULL;
121   impl->d_array = NULL;
122   return CEED_ERROR_SUCCESS;
123 }
124 
125 //------------------------------------------------------------------------------
126 // Check if CeedVector has any valid pointer
127 //------------------------------------------------------------------------------
128 static inline int CeedVectorHasValidArray_Cuda(const CeedVector vec, bool *has_valid_array) {
129   CeedVector_Cuda *impl;
130 
131   CeedCallBackend(CeedVectorGetData(vec, &impl));
132   *has_valid_array = impl->h_array || impl->d_array;
133   return CEED_ERROR_SUCCESS;
134 }
135 
136 //------------------------------------------------------------------------------
137 // Check if has array of given type
138 //------------------------------------------------------------------------------
139 static inline int CeedVectorHasArrayOfType_Cuda(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) {
140   CeedVector_Cuda *impl;
141 
142   CeedCallBackend(CeedVectorGetData(vec, &impl));
143   switch (mem_type) {
144     case CEED_MEM_HOST:
145       *has_array_of_type = impl->h_array_borrowed || impl->h_array_owned;
146       break;
147     case CEED_MEM_DEVICE:
148       *has_array_of_type = impl->d_array_borrowed || impl->d_array_owned;
149       break;
150   }
151   return CEED_ERROR_SUCCESS;
152 }
153 
154 //------------------------------------------------------------------------------
155 // Check if has borrowed array of given type
156 //------------------------------------------------------------------------------
157 static inline int CeedVectorHasBorrowedArrayOfType_Cuda(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) {
158   CeedVector_Cuda *impl;
159 
160   CeedCallBackend(CeedVectorGetData(vec, &impl));
161   switch (mem_type) {
162     case CEED_MEM_HOST:
163       *has_borrowed_array_of_type = impl->h_array_borrowed;
164       break;
165     case CEED_MEM_DEVICE:
166       *has_borrowed_array_of_type = impl->d_array_borrowed;
167       break;
168   }
169   return CEED_ERROR_SUCCESS;
170 }
171 
172 //------------------------------------------------------------------------------
173 // Set array from host
174 //------------------------------------------------------------------------------
175 static int CeedVectorSetArrayHost_Cuda(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
176   CeedSize         length;
177   CeedVector_Cuda *impl;
178 
179   CeedCallBackend(CeedVectorGetData(vec, &impl));
180   CeedCallBackend(CeedVectorGetLength(vec, &length));
181 
182   CeedCallBackend(CeedSetHostCeedScalarArray(array, copy_mode, length, (const CeedScalar **)&impl->h_array_owned,
183                                              (const CeedScalar **)&impl->h_array_borrowed, (const CeedScalar **)&impl->h_array));
184   return CEED_ERROR_SUCCESS;
185 }
186 
187 //------------------------------------------------------------------------------
188 // Set array from device
189 //------------------------------------------------------------------------------
190 static int CeedVectorSetArrayDevice_Cuda(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
191   CeedSize         length;
192   Ceed             ceed;
193   CeedVector_Cuda *impl;
194 
195   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
196   CeedCallBackend(CeedVectorGetData(vec, &impl));
197   CeedCallBackend(CeedVectorGetLength(vec, &length));
198 
199   CeedCallBackend(CeedSetDeviceCeedScalarArray_Cuda(ceed, array, copy_mode, length, (const CeedScalar **)&impl->d_array_owned,
200                                                     (const CeedScalar **)&impl->d_array_borrowed, (const CeedScalar **)&impl->d_array));
201   CeedCallBackend(CeedDestroy(&ceed));
202   return CEED_ERROR_SUCCESS;
203 }
204 
205 //------------------------------------------------------------------------------
206 // Set the array used by a vector,
207 //   freeing any previously allocated array if applicable
208 //------------------------------------------------------------------------------
209 static int CeedVectorSetArray_Cuda(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) {
210   CeedVector_Cuda *impl;
211 
212   CeedCallBackend(CeedVectorGetData(vec, &impl));
213   CeedCallBackend(CeedVectorSetAllInvalid_Cuda(vec));
214   switch (mem_type) {
215     case CEED_MEM_HOST:
216       return CeedVectorSetArrayHost_Cuda(vec, copy_mode, array);
217     case CEED_MEM_DEVICE:
218       return CeedVectorSetArrayDevice_Cuda(vec, copy_mode, array);
219   }
220   return CEED_ERROR_UNSUPPORTED;
221 }
222 
223 //------------------------------------------------------------------------------
224 // Copy host array to value strided
225 //------------------------------------------------------------------------------
226 static int CeedHostCopyStrided_Cuda(CeedScalar *h_array, CeedSize start, CeedSize step, CeedSize length, CeedScalar *h_copy_array) {
227   for (CeedSize i = start; i < length; i += step) h_copy_array[i] = h_array[i];
228   return CEED_ERROR_SUCCESS;
229 }
230 
231 //------------------------------------------------------------------------------
232 // Copy device array to value strided (impl in .cu file)
233 //------------------------------------------------------------------------------
234 int CeedDeviceCopyStrided_Cuda(CeedScalar *d_array, CeedSize start, CeedSize step, CeedSize length, CeedScalar *d_copy_array);
235 
236 //------------------------------------------------------------------------------
237 // Copy a vector to a value strided
238 //------------------------------------------------------------------------------
239 static int CeedVectorCopyStrided_Cuda(CeedVector vec, CeedSize start, CeedSize step, CeedVector vec_copy) {
240   CeedSize         length;
241   CeedVector_Cuda *impl;
242 
243   CeedCallBackend(CeedVectorGetData(vec, &impl));
244   {
245     CeedSize length_vec, length_copy;
246 
247     CeedCallBackend(CeedVectorGetLength(vec, &length_vec));
248     CeedCallBackend(CeedVectorGetLength(vec_copy, &length_copy));
249     length = length_vec < length_copy ? length_vec : length_copy;
250   }
251   // Set value for synced device/host array
252   if (impl->d_array) {
253     CeedScalar *copy_array;
254 
255     CeedCallBackend(CeedVectorGetArray(vec_copy, CEED_MEM_DEVICE, &copy_array));
256     CeedCallBackend(CeedDeviceCopyStrided_Cuda(impl->d_array, start, step, length, copy_array));
257     CeedCallBackend(CeedVectorRestoreArray(vec_copy, &copy_array));
258   } else if (impl->h_array) {
259     CeedScalar *copy_array;
260 
261     CeedCallBackend(CeedVectorGetArray(vec_copy, CEED_MEM_HOST, &copy_array));
262     CeedCallBackend(CeedHostCopyStrided_Cuda(impl->h_array, start, step, length, copy_array));
263     CeedCallBackend(CeedVectorRestoreArray(vec_copy, &copy_array));
264   } else {
265     return CeedError(CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND, "CeedVector must have valid data set");
266   }
267   return CEED_ERROR_SUCCESS;
268 }
269 
270 //------------------------------------------------------------------------------
271 // Set host array to value
272 //------------------------------------------------------------------------------
273 static int CeedHostSetValue_Cuda(CeedScalar *h_array, CeedSize length, CeedScalar val) {
274   for (CeedSize i = 0; i < length; i++) h_array[i] = val;
275   return CEED_ERROR_SUCCESS;
276 }
277 
278 //------------------------------------------------------------------------------
279 // Set device array to value (impl in .cu file)
280 //------------------------------------------------------------------------------
281 int CeedDeviceSetValue_Cuda(CeedScalar *d_array, CeedSize length, CeedScalar val);
282 
283 //------------------------------------------------------------------------------
284 // Set a vector to a value
285 //------------------------------------------------------------------------------
286 static int CeedVectorSetValue_Cuda(CeedVector vec, CeedScalar val) {
287   CeedSize         length;
288   CeedVector_Cuda *impl;
289 
290   CeedCallBackend(CeedVectorGetData(vec, &impl));
291   CeedCallBackend(CeedVectorGetLength(vec, &length));
292   // Set value for synced device/host array
293   if (!impl->d_array && !impl->h_array) {
294     if (impl->d_array_borrowed) {
295       impl->d_array = impl->d_array_borrowed;
296     } else if (impl->h_array_borrowed) {
297       impl->h_array = impl->h_array_borrowed;
298     } else if (impl->d_array_owned) {
299       impl->d_array = impl->d_array_owned;
300     } else if (impl->h_array_owned) {
301       impl->h_array = impl->h_array_owned;
302     } else {
303       CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL));
304     }
305   }
306   if (impl->d_array) {
307     CeedCallBackend(CeedDeviceSetValue_Cuda(impl->d_array, length, val));
308     impl->h_array = NULL;
309   }
310   if (impl->h_array) {
311     CeedCallBackend(CeedHostSetValue_Cuda(impl->h_array, length, val));
312     impl->d_array = NULL;
313   }
314   return CEED_ERROR_SUCCESS;
315 }
316 
317 //------------------------------------------------------------------------------
318 // Set host array to value strided
319 //------------------------------------------------------------------------------
320 static int CeedHostSetValueStrided_Cuda(CeedScalar *h_array, CeedSize start, CeedSize step, CeedSize length, CeedScalar val) {
321   for (CeedSize i = start; i < length; i += step) h_array[i] = val;
322   return CEED_ERROR_SUCCESS;
323 }
324 
325 //------------------------------------------------------------------------------
326 // Set device array to value strided (impl in .cu file)
327 //------------------------------------------------------------------------------
328 int CeedDeviceSetValueStrided_Cuda(CeedScalar *d_array, CeedSize start, CeedSize step, CeedSize length, CeedScalar val);
329 
330 //------------------------------------------------------------------------------
331 // Set a vector to a value strided
332 //------------------------------------------------------------------------------
333 static int CeedVectorSetValueStrided_Cuda(CeedVector vec, CeedSize start, CeedSize step, CeedScalar val) {
334   CeedSize         length;
335   CeedVector_Cuda *impl;
336 
337   CeedCallBackend(CeedVectorGetData(vec, &impl));
338   CeedCallBackend(CeedVectorGetLength(vec, &length));
339   // Set value for synced device/host array
340   if (impl->d_array) {
341     CeedCallBackend(CeedDeviceSetValueStrided_Cuda(impl->d_array, start, step, length, val));
342     impl->h_array = NULL;
343   } else if (impl->h_array) {
344     CeedCallBackend(CeedHostSetValueStrided_Cuda(impl->h_array, start, step, length, val));
345     impl->d_array = NULL;
346   } else {
347     return CeedError(CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND, "CeedVector must have valid data set");
348   }
349   return CEED_ERROR_SUCCESS;
350 }
351 
352 //------------------------------------------------------------------------------
353 // Vector Take Array
354 //------------------------------------------------------------------------------
355 static int CeedVectorTakeArray_Cuda(CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
356   CeedVector_Cuda *impl;
357 
358   CeedCallBackend(CeedVectorGetData(vec, &impl));
359   // Sync array to requested mem_type
360   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
361   // Update pointer
362   switch (mem_type) {
363     case CEED_MEM_HOST:
364       (*array)               = impl->h_array_borrowed;
365       impl->h_array_borrowed = NULL;
366       impl->h_array          = NULL;
367       break;
368     case CEED_MEM_DEVICE:
369       (*array)               = impl->d_array_borrowed;
370       impl->d_array_borrowed = NULL;
371       impl->d_array          = NULL;
372       break;
373   }
374   return CEED_ERROR_SUCCESS;
375 }
376 
377 //------------------------------------------------------------------------------
378 // Core logic for array syncronization for GetArray.
379 //   If a different memory type is most up to date, this will perform a copy
380 //------------------------------------------------------------------------------
381 static int CeedVectorGetArrayCore_Cuda(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
382   CeedVector_Cuda *impl;
383 
384   CeedCallBackend(CeedVectorGetData(vec, &impl));
385   // Sync array to requested mem_type
386   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
387   // Update pointer
388   switch (mem_type) {
389     case CEED_MEM_HOST:
390       *array = impl->h_array;
391       break;
392     case CEED_MEM_DEVICE:
393       *array = impl->d_array;
394       break;
395   }
396   return CEED_ERROR_SUCCESS;
397 }
398 
399 //------------------------------------------------------------------------------
400 // Get read-only access to a vector via the specified mem_type
401 //------------------------------------------------------------------------------
402 static int CeedVectorGetArrayRead_Cuda(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) {
403   return CeedVectorGetArrayCore_Cuda(vec, mem_type, (CeedScalar **)array);
404 }
405 
406 //------------------------------------------------------------------------------
407 // Get read/write access to a vector via the specified mem_type
408 //------------------------------------------------------------------------------
409 static int CeedVectorGetArray_Cuda(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
410   CeedVector_Cuda *impl;
411 
412   CeedCallBackend(CeedVectorGetData(vec, &impl));
413   CeedCallBackend(CeedVectorGetArrayCore_Cuda(vec, mem_type, array));
414   CeedCallBackend(CeedVectorSetAllInvalid_Cuda(vec));
415   switch (mem_type) {
416     case CEED_MEM_HOST:
417       impl->h_array = *array;
418       break;
419     case CEED_MEM_DEVICE:
420       impl->d_array = *array;
421       break;
422   }
423   return CEED_ERROR_SUCCESS;
424 }
425 
426 //------------------------------------------------------------------------------
427 // Get write access to a vector via the specified mem_type
428 //------------------------------------------------------------------------------
429 static int CeedVectorGetArrayWrite_Cuda(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
430   bool             has_array_of_type = true;
431   CeedVector_Cuda *impl;
432 
433   CeedCallBackend(CeedVectorGetData(vec, &impl));
434   CeedCallBackend(CeedVectorHasArrayOfType_Cuda(vec, mem_type, &has_array_of_type));
435   if (!has_array_of_type) {
436     // Allocate if array is not yet allocated
437     CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL));
438   } else {
439     // Select dirty array
440     switch (mem_type) {
441       case CEED_MEM_HOST:
442         if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed;
443         else impl->h_array = impl->h_array_owned;
444         break;
445       case CEED_MEM_DEVICE:
446         if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed;
447         else impl->d_array = impl->d_array_owned;
448     }
449   }
450   return CeedVectorGetArray_Cuda(vec, mem_type, array);
451 }
452 
453 //------------------------------------------------------------------------------
454 // Get the norm of a CeedVector
455 //------------------------------------------------------------------------------
456 static int CeedVectorNorm_Cuda(CeedVector vec, CeedNormType type, CeedScalar *norm) {
457   Ceed     ceed;
458   CeedSize length;
459 #if CUDA_VERSION < 12000
460   CeedSize num_calls;
461 #endif
462   const CeedScalar *d_array;
463   CeedVector_Cuda  *impl;
464   cublasHandle_t    handle;
465 
466   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
467   CeedCallBackend(CeedVectorGetData(vec, &impl));
468   CeedCallBackend(CeedVectorGetLength(vec, &length));
469   CeedCallBackend(CeedGetCublasHandle_Cuda(ceed, &handle));
470 
471 #if CUDA_VERSION < 12000
472   // With CUDA 12, we can use the 64-bit integer interface. Prior to that,
473   // we need to check if the vector is too long to handle with int32,
474   // and if so, divide it into subsections for repeated cuBLAS calls.
475   num_calls = length / INT_MAX;
476   if (length % INT_MAX > 0) num_calls += 1;
477 #endif
478 
479   // Compute norm
480   CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array));
481   switch (type) {
482     case CEED_NORM_1: {
483       *norm = 0.0;
484       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
485 #if CUDA_VERSION >= 12000  // We have CUDA 12, and can use 64-bit integers
486         CeedCallCublas(ceed, cublasSasum_64(handle, (int64_t)length, (float *)d_array, 1, (float *)norm));
487 #else
488         float  sub_norm = 0.0;
489         float *d_array_start;
490 
491         for (CeedInt i = 0; i < num_calls; i++) {
492           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
493           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
494           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
495 
496           CeedCallCublas(ceed, cublasSasum(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
497           *norm += sub_norm;
498         }
499 #endif
500       } else {
501 #if CUDA_VERSION >= 12000
502         CeedCallCublas(ceed, cublasDasum_64(handle, (int64_t)length, (double *)d_array, 1, (double *)norm));
503 #else
504         double  sub_norm = 0.0;
505         double *d_array_start;
506 
507         for (CeedInt i = 0; i < num_calls; i++) {
508           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
509           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
510           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
511 
512           CeedCallCublas(ceed, cublasDasum(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
513           *norm += sub_norm;
514         }
515 #endif
516       }
517       break;
518     }
519     case CEED_NORM_2: {
520       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
521 #if CUDA_VERSION >= 12000
522         CeedCallCublas(ceed, cublasSnrm2_64(handle, (int64_t)length, (float *)d_array, 1, (float *)norm));
523 #else
524         float  sub_norm = 0.0, norm_sum = 0.0;
525         float *d_array_start;
526 
527         for (CeedInt i = 0; i < num_calls; i++) {
528           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
529           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
530           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
531 
532           CeedCallCublas(ceed, cublasSnrm2(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
533           norm_sum += sub_norm * sub_norm;
534         }
535         *norm = sqrt(norm_sum);
536 #endif
537       } else {
538 #if CUDA_VERSION >= 12000
539         CeedCallCublas(ceed, cublasDnrm2_64(handle, (int64_t)length, (double *)d_array, 1, (double *)norm));
540 #else
541         double  sub_norm = 0.0, norm_sum = 0.0;
542         double *d_array_start;
543 
544         for (CeedInt i = 0; i < num_calls; i++) {
545           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
546           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
547           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
548 
549           CeedCallCublas(ceed, cublasDnrm2(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
550           norm_sum += sub_norm * sub_norm;
551         }
552         *norm = sqrt(norm_sum);
553 #endif
554       }
555       break;
556     }
557     case CEED_NORM_MAX: {
558       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
559 #if CUDA_VERSION >= 12000
560         int64_t    index;
561         CeedScalar norm_no_abs;
562 
563         CeedCallCublas(ceed, cublasIsamax_64(handle, (int64_t)length, (float *)d_array, 1, &index));
564         CeedCallCuda(ceed, cudaMemcpy(&norm_no_abs, impl->d_array + index - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
565         *norm = fabs(norm_no_abs);
566 #else
567         CeedInt index;
568         float   sub_max = 0.0, current_max = 0.0;
569         float  *d_array_start;
570 
571         for (CeedInt i = 0; i < num_calls; i++) {
572           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
573           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
574           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
575 
576           CeedCallCublas(ceed, cublasIsamax(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &index));
577           CeedCallCuda(ceed, cudaMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
578           if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
579         }
580         *norm = current_max;
581 #endif
582       } else {
583 #if CUDA_VERSION >= 12000
584         int64_t    index;
585         CeedScalar norm_no_abs;
586 
587         CeedCallCublas(ceed, cublasIdamax_64(handle, (int64_t)length, (double *)d_array, 1, &index));
588         CeedCallCuda(ceed, cudaMemcpy(&norm_no_abs, impl->d_array + index - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
589         *norm = fabs(norm_no_abs);
590 #else
591         CeedInt index;
592         double  sub_max = 0.0, current_max = 0.0;
593         double *d_array_start;
594 
595         for (CeedInt i = 0; i < num_calls; i++) {
596           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
597           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
598           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
599 
600           CeedCallCublas(ceed, cublasIdamax(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &index));
601           CeedCallCuda(ceed, cudaMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), cudaMemcpyDeviceToHost));
602           if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
603         }
604         *norm = current_max;
605 #endif
606       }
607       break;
608     }
609   }
610   CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array));
611   CeedCallBackend(CeedDestroy(&ceed));
612   return CEED_ERROR_SUCCESS;
613 }
614 
615 //------------------------------------------------------------------------------
616 // Take reciprocal of a vector on host
617 //------------------------------------------------------------------------------
618 static int CeedHostReciprocal_Cuda(CeedScalar *h_array, CeedSize length) {
619   for (CeedSize i = 0; i < length; i++) {
620     if (fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i];
621   }
622   return CEED_ERROR_SUCCESS;
623 }
624 
625 //------------------------------------------------------------------------------
626 // Take reciprocal of a vector on device (impl in .cu file)
627 //------------------------------------------------------------------------------
628 int CeedDeviceReciprocal_Cuda(CeedScalar *d_array, CeedSize length);
629 
630 //------------------------------------------------------------------------------
631 // Take reciprocal of a vector
632 //------------------------------------------------------------------------------
633 static int CeedVectorReciprocal_Cuda(CeedVector vec) {
634   CeedSize         length;
635   CeedVector_Cuda *impl;
636 
637   CeedCallBackend(CeedVectorGetData(vec, &impl));
638   CeedCallBackend(CeedVectorGetLength(vec, &length));
639   // Set value for synced device/host array
640   if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Cuda(impl->d_array, length));
641   if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Cuda(impl->h_array, length));
642   return CEED_ERROR_SUCCESS;
643 }
644 
645 //------------------------------------------------------------------------------
646 // Compute x = alpha x on the host
647 //------------------------------------------------------------------------------
648 static int CeedHostScale_Cuda(CeedScalar *x_array, CeedScalar alpha, CeedSize length) {
649   for (CeedSize i = 0; i < length; i++) x_array[i] *= alpha;
650   return CEED_ERROR_SUCCESS;
651 }
652 
653 //------------------------------------------------------------------------------
654 // Compute x = alpha x on device (impl in .cu file)
655 //------------------------------------------------------------------------------
656 int CeedDeviceScale_Cuda(CeedScalar *x_array, CeedScalar alpha, CeedSize length);
657 
658 //------------------------------------------------------------------------------
659 // Compute x = alpha x
660 //------------------------------------------------------------------------------
661 static int CeedVectorScale_Cuda(CeedVector x, CeedScalar alpha) {
662   CeedSize         length;
663   CeedVector_Cuda *x_impl;
664 
665   CeedCallBackend(CeedVectorGetData(x, &x_impl));
666   CeedCallBackend(CeedVectorGetLength(x, &length));
667   // Set value for synced device/host array
668   if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Cuda(x_impl->d_array, alpha, length));
669   if (x_impl->h_array) CeedCallBackend(CeedHostScale_Cuda(x_impl->h_array, alpha, length));
670   return CEED_ERROR_SUCCESS;
671 }
672 
673 //------------------------------------------------------------------------------
674 // Compute y = alpha x + y on the host
675 //------------------------------------------------------------------------------
676 static int CeedHostAXPY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) {
677   for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i];
678   return CEED_ERROR_SUCCESS;
679 }
680 
681 //------------------------------------------------------------------------------
682 // Compute y = alpha x + y on device (impl in .cu file)
683 //------------------------------------------------------------------------------
684 int CeedDeviceAXPY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length);
685 
686 //------------------------------------------------------------------------------
687 // Compute y = alpha x + y
688 //------------------------------------------------------------------------------
689 static int CeedVectorAXPY_Cuda(CeedVector y, CeedScalar alpha, CeedVector x) {
690   CeedSize         length;
691   CeedVector_Cuda *y_impl, *x_impl;
692 
693   CeedCallBackend(CeedVectorGetData(y, &y_impl));
694   CeedCallBackend(CeedVectorGetData(x, &x_impl));
695   CeedCallBackend(CeedVectorGetLength(y, &length));
696   // Set value for synced device/host array
697   if (y_impl->d_array) {
698     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
699     CeedCallBackend(CeedDeviceAXPY_Cuda(y_impl->d_array, alpha, x_impl->d_array, length));
700   }
701   if (y_impl->h_array) {
702     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
703     CeedCallBackend(CeedHostAXPY_Cuda(y_impl->h_array, alpha, x_impl->h_array, length));
704   }
705   return CEED_ERROR_SUCCESS;
706 }
707 
708 //------------------------------------------------------------------------------
709 // Compute y = alpha x + beta y on the host
710 //------------------------------------------------------------------------------
711 static int CeedHostAXPBY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length) {
712   for (CeedSize i = 0; i < length; i++) y_array[i] = alpha * x_array[i] + beta * y_array[i];
713   return CEED_ERROR_SUCCESS;
714 }
715 
716 //------------------------------------------------------------------------------
717 // Compute y = alpha x + beta y on device (impl in .cu file)
718 //------------------------------------------------------------------------------
719 int CeedDeviceAXPBY_Cuda(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length);
720 
721 //------------------------------------------------------------------------------
722 // Compute y = alpha x + beta y
723 //------------------------------------------------------------------------------
724 static int CeedVectorAXPBY_Cuda(CeedVector y, CeedScalar alpha, CeedScalar beta, CeedVector x) {
725   CeedSize         length;
726   CeedVector_Cuda *y_impl, *x_impl;
727 
728   CeedCallBackend(CeedVectorGetData(y, &y_impl));
729   CeedCallBackend(CeedVectorGetData(x, &x_impl));
730   CeedCallBackend(CeedVectorGetLength(y, &length));
731   // Set value for synced device/host array
732   if (y_impl->d_array) {
733     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
734     CeedCallBackend(CeedDeviceAXPBY_Cuda(y_impl->d_array, alpha, beta, x_impl->d_array, length));
735   }
736   if (y_impl->h_array) {
737     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
738     CeedCallBackend(CeedHostAXPBY_Cuda(y_impl->h_array, alpha, beta, x_impl->h_array, length));
739   }
740   return CEED_ERROR_SUCCESS;
741 }
742 
743 //------------------------------------------------------------------------------
744 // Compute the pointwise multiplication w = x .* y on the host
745 //------------------------------------------------------------------------------
746 static int CeedHostPointwiseMult_Cuda(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) {
747   for (CeedSize i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i];
748   return CEED_ERROR_SUCCESS;
749 }
750 
751 //------------------------------------------------------------------------------
752 // Compute the pointwise multiplication w = x .* y on device (impl in .cu file)
753 //------------------------------------------------------------------------------
754 int CeedDevicePointwiseMult_Cuda(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length);
755 
756 //------------------------------------------------------------------------------
757 // Compute the pointwise multiplication w = x .* y
758 //------------------------------------------------------------------------------
759 static int CeedVectorPointwiseMult_Cuda(CeedVector w, CeedVector x, CeedVector y) {
760   CeedSize         length;
761   CeedVector_Cuda *w_impl, *x_impl, *y_impl;
762 
763   CeedCallBackend(CeedVectorGetData(w, &w_impl));
764   CeedCallBackend(CeedVectorGetData(x, &x_impl));
765   CeedCallBackend(CeedVectorGetData(y, &y_impl));
766   CeedCallBackend(CeedVectorGetLength(w, &length));
767   // Set value for synced device/host array
768   if (!w_impl->d_array && !w_impl->h_array) {
769     CeedCallBackend(CeedVectorSetValue(w, 0.0));
770   }
771   if (w_impl->d_array) {
772     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
773     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE));
774     CeedCallBackend(CeedDevicePointwiseMult_Cuda(w_impl->d_array, x_impl->d_array, y_impl->d_array, length));
775   }
776   if (w_impl->h_array) {
777     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
778     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST));
779     CeedCallBackend(CeedHostPointwiseMult_Cuda(w_impl->h_array, x_impl->h_array, y_impl->h_array, length));
780   }
781   return CEED_ERROR_SUCCESS;
782 }
783 
784 //------------------------------------------------------------------------------
785 // Destroy the vector
786 //------------------------------------------------------------------------------
787 static int CeedVectorDestroy_Cuda(const CeedVector vec) {
788   CeedVector_Cuda *impl;
789 
790   CeedCallBackend(CeedVectorGetData(vec, &impl));
791   CeedCallCuda(CeedVectorReturnCeed(vec), cudaFree(impl->d_array_owned));
792   CeedCallBackend(CeedFree(&impl->h_array_owned));
793   CeedCallBackend(CeedFree(&impl));
794   return CEED_ERROR_SUCCESS;
795 }
796 
797 //------------------------------------------------------------------------------
798 // Create a vector of the specified length (does not allocate memory)
799 //------------------------------------------------------------------------------
800 int CeedVectorCreate_Cuda(CeedSize n, CeedVector vec) {
801   CeedVector_Cuda *impl;
802   Ceed             ceed;
803 
804   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
805   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Cuda));
806   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Cuda));
807   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Cuda));
808   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Cuda));
809   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "CopyStrided", CeedVectorCopyStrided_Cuda));
810   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", CeedVectorSetValue_Cuda));
811   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValueStrided", CeedVectorSetValueStrided_Cuda));
812   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Cuda));
813   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Cuda));
814   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Cuda));
815   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Cuda));
816   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Norm", CeedVectorNorm_Cuda));
817   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Cuda));
818   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Scale", CeedVectorScale_Cuda));
819   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPY", CeedVectorAXPY_Cuda));
820   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPBY", CeedVectorAXPBY_Cuda));
821   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Cuda));
822   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Cuda));
823   CeedCallBackend(CeedDestroy(&ceed));
824   CeedCallBackend(CeedCalloc(1, &impl));
825   CeedCallBackend(CeedVectorSetData(vec, impl));
826   return CEED_ERROR_SUCCESS;
827 }
828 
829 //------------------------------------------------------------------------------
830