xref: /libCEED/backends/hip-ref/ceed-hip-ref-vector.c (revision 4ed2b27715919084fc6a9c0e6a05a1e8adc518b8)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <math.h>
11 #include <stdbool.h>
12 #include <string.h>
13 #include <hip/hip_runtime.h>
14 
15 #include "../hip/ceed-hip-common.h"
16 #include "ceed-hip-ref.h"
17 
18 //------------------------------------------------------------------------------
19 // Check if host/device sync is needed
20 //------------------------------------------------------------------------------
21 static inline int CeedVectorNeedSync_Hip(const CeedVector vec, CeedMemType mem_type, bool *need_sync) {
22   CeedVector_Hip *impl;
23   bool            has_valid_array = false;
24 
25   CeedCallBackend(CeedVectorGetData(vec, &impl));
26   CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array));
27   switch (mem_type) {
28     case CEED_MEM_HOST:
29       *need_sync = has_valid_array && !impl->h_array;
30       break;
31     case CEED_MEM_DEVICE:
32       *need_sync = has_valid_array && !impl->d_array;
33       break;
34   }
35   return CEED_ERROR_SUCCESS;
36 }
37 
38 //------------------------------------------------------------------------------
39 // Sync host to device
40 //------------------------------------------------------------------------------
41 static inline int CeedVectorSyncH2D_Hip(const CeedVector vec) {
42   CeedSize        length;
43   size_t          bytes;
44   CeedVector_Hip *impl;
45 
46   CeedCallBackend(CeedVectorGetData(vec, &impl));
47 
48   CeedCheck(impl->h_array, CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND, "No valid host data to sync to device");
49 
50   CeedCallBackend(CeedVectorGetLength(vec, &length));
51   bytes = length * sizeof(CeedScalar);
52   if (impl->d_array_borrowed) {
53     impl->d_array = impl->d_array_borrowed;
54   } else if (impl->d_array_owned) {
55     impl->d_array = impl->d_array_owned;
56   } else {
57     CeedCallHip(CeedVectorReturnCeed(vec), hipMalloc((void **)&impl->d_array_owned, bytes));
58     impl->d_array = impl->d_array_owned;
59   }
60   CeedCallHip(CeedVectorReturnCeed(vec), hipMemcpy(impl->d_array, impl->h_array, bytes, hipMemcpyHostToDevice));
61   return CEED_ERROR_SUCCESS;
62 }
63 
64 //------------------------------------------------------------------------------
65 // Sync device to host
66 //------------------------------------------------------------------------------
67 static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) {
68   CeedSize        length;
69   size_t          bytes;
70   CeedVector_Hip *impl;
71 
72   CeedCallBackend(CeedVectorGetData(vec, &impl));
73 
74   CeedCheck(impl->d_array, CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND, "No valid device data to sync to host");
75 
76   if (impl->h_array_borrowed) {
77     impl->h_array = impl->h_array_borrowed;
78   } else if (impl->h_array_owned) {
79     impl->h_array = impl->h_array_owned;
80   } else {
81     CeedSize length;
82 
83     CeedCallBackend(CeedVectorGetLength(vec, &length));
84     CeedCallBackend(CeedCalloc(length, &impl->h_array_owned));
85     impl->h_array = impl->h_array_owned;
86   }
87 
88   CeedCallBackend(CeedVectorGetLength(vec, &length));
89   bytes = length * sizeof(CeedScalar);
90   CeedCallHip(CeedVectorReturnCeed(vec), hipMemcpy(impl->h_array, impl->d_array, bytes, hipMemcpyDeviceToHost));
91   return CEED_ERROR_SUCCESS;
92 }
93 
94 //------------------------------------------------------------------------------
95 // Sync arrays
96 //------------------------------------------------------------------------------
97 static int CeedVectorSyncArray_Hip(const CeedVector vec, CeedMemType mem_type) {
98   bool need_sync = false;
99 
100   // Check whether device/host sync is needed
101   CeedCallBackend(CeedVectorNeedSync_Hip(vec, mem_type, &need_sync));
102   if (!need_sync) return CEED_ERROR_SUCCESS;
103 
104   switch (mem_type) {
105     case CEED_MEM_HOST:
106       return CeedVectorSyncD2H_Hip(vec);
107     case CEED_MEM_DEVICE:
108       return CeedVectorSyncH2D_Hip(vec);
109   }
110   return CEED_ERROR_UNSUPPORTED;
111 }
112 
113 //------------------------------------------------------------------------------
114 // Set all pointers as invalid
115 //------------------------------------------------------------------------------
116 static inline int CeedVectorSetAllInvalid_Hip(const CeedVector vec) {
117   CeedVector_Hip *impl;
118 
119   CeedCallBackend(CeedVectorGetData(vec, &impl));
120   impl->h_array = NULL;
121   impl->d_array = NULL;
122   return CEED_ERROR_SUCCESS;
123 }
124 
125 //------------------------------------------------------------------------------
126 // Check if CeedVector has any valid pointer
127 //------------------------------------------------------------------------------
128 static inline int CeedVectorHasValidArray_Hip(const CeedVector vec, bool *has_valid_array) {
129   CeedVector_Hip *impl;
130 
131   CeedCallBackend(CeedVectorGetData(vec, &impl));
132   *has_valid_array = impl->h_array || impl->d_array;
133   return CEED_ERROR_SUCCESS;
134 }
135 
136 //------------------------------------------------------------------------------
137 // Check if has array of given type
138 //------------------------------------------------------------------------------
139 static inline int CeedVectorHasArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) {
140   CeedVector_Hip *impl;
141 
142   CeedCallBackend(CeedVectorGetData(vec, &impl));
143   switch (mem_type) {
144     case CEED_MEM_HOST:
145       *has_array_of_type = impl->h_array_borrowed || impl->h_array_owned;
146       break;
147     case CEED_MEM_DEVICE:
148       *has_array_of_type = impl->d_array_borrowed || impl->d_array_owned;
149       break;
150   }
151   return CEED_ERROR_SUCCESS;
152 }
153 
154 //------------------------------------------------------------------------------
155 // Check if has borrowed array of given type
156 //------------------------------------------------------------------------------
157 static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) {
158   CeedVector_Hip *impl;
159 
160   CeedCallBackend(CeedVectorGetData(vec, &impl));
161   switch (mem_type) {
162     case CEED_MEM_HOST:
163       *has_borrowed_array_of_type = impl->h_array_borrowed;
164       break;
165     case CEED_MEM_DEVICE:
166       *has_borrowed_array_of_type = impl->d_array_borrowed;
167       break;
168   }
169   return CEED_ERROR_SUCCESS;
170 }
171 
172 //------------------------------------------------------------------------------
173 // Set array from host
174 //------------------------------------------------------------------------------
175 static int CeedVectorSetArrayHost_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
176   CeedSize        length;
177   CeedVector_Hip *impl;
178 
179   CeedCallBackend(CeedVectorGetData(vec, &impl));
180   CeedCallBackend(CeedVectorGetLength(vec, &length));
181 
182   CeedCallBackend(CeedSetHostCeedScalarArray(array, copy_mode, length, (const CeedScalar **)&impl->h_array_owned,
183                                              (const CeedScalar **)&impl->h_array_borrowed, (const CeedScalar **)&impl->h_array));
184   return CEED_ERROR_SUCCESS;
185 }
186 
187 //------------------------------------------------------------------------------
188 // Set array from device
189 //------------------------------------------------------------------------------
190 static int CeedVectorSetArrayDevice_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
191   CeedSize        length;
192   Ceed            ceed;
193   CeedVector_Hip *impl;
194 
195   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
196   CeedCallBackend(CeedVectorGetData(vec, &impl));
197   CeedCallBackend(CeedVectorGetLength(vec, &length));
198 
199   CeedCallBackend(CeedSetDeviceCeedScalarArray_Hip(ceed, array, copy_mode, length, (const CeedScalar **)&impl->d_array_owned,
200                                                    (const CeedScalar **)&impl->d_array_borrowed, (const CeedScalar **)&impl->d_array));
201   CeedCallBackend(CeedDestroy(&ceed));
202   return CEED_ERROR_SUCCESS;
203 }
204 
205 //------------------------------------------------------------------------------
206 // Set the array used by a vector,
207 //   freeing any previously allocated array if applicable
208 //------------------------------------------------------------------------------
209 static int CeedVectorSetArray_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) {
210   CeedVector_Hip *impl;
211 
212   CeedCallBackend(CeedVectorGetData(vec, &impl));
213   CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
214   switch (mem_type) {
215     case CEED_MEM_HOST:
216       return CeedVectorSetArrayHost_Hip(vec, copy_mode, array);
217     case CEED_MEM_DEVICE:
218       return CeedVectorSetArrayDevice_Hip(vec, copy_mode, array);
219   }
220   return CEED_ERROR_UNSUPPORTED;
221 }
222 
223 //------------------------------------------------------------------------------
224 // Copy host array to value strided
225 //------------------------------------------------------------------------------
226 static int CeedHostCopyStrided_Hip(CeedScalar *h_array, CeedSize start, CeedSize step, CeedSize length, CeedScalar *h_copy_array) {
227   for (CeedSize i = start; i < length; i += step) h_copy_array[i] = h_array[i];
228   return CEED_ERROR_SUCCESS;
229 }
230 
231 //------------------------------------------------------------------------------
232 // Copy device array to value strided (impl in .hip.cpp file)
233 //------------------------------------------------------------------------------
234 int CeedDeviceCopyStrided_Hip(CeedScalar *d_array, CeedSize start, CeedSize step, CeedSize length, CeedScalar *d_copy_array);
235 
236 //------------------------------------------------------------------------------
237 // Copy a vector to a value strided
238 //------------------------------------------------------------------------------
239 static int CeedVectorCopyStrided_Hip(CeedVector vec, CeedSize start, CeedSize step, CeedVector vec_copy) {
240   CeedSize        length;
241   CeedVector_Hip *impl;
242 
243   CeedCallBackend(CeedVectorGetData(vec, &impl));
244   {
245     CeedSize length_vec, length_copy;
246 
247     CeedCallBackend(CeedVectorGetLength(vec, &length_vec));
248     CeedCallBackend(CeedVectorGetLength(vec_copy, &length_copy));
249     length = length_vec < length_copy ? length_vec : length_copy;
250   }
251   // Set value for synced device/host array
252   if (impl->d_array) {
253     CeedScalar *copy_array;
254 
255     CeedCallBackend(CeedVectorGetArray(vec_copy, CEED_MEM_DEVICE, &copy_array));
256 #if (HIP_VERSION >= 60000000)
257     hipblasHandle_t handle;
258     Ceed            ceed;
259 
260     CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
261     CeedCallBackend(CeedGetHipblasHandle_Hip(ceed, &handle));
262 #if defined(CEED_SCALAR_IS_FP32)
263     CeedCallHipblas(ceed, hipblasScopy_64(handle, (int64_t)length, impl->d_array + start, (int64_t)step, copy_array + start, (int64_t)step));
264 #else  /* CEED_SCALAR */
265     CeedCallHipblas(ceed, hipblasDcopy_64(handle, (int64_t)length, impl->d_array + start, (int64_t)step, copy_array + start, (int64_t)step));
266 #endif /* CEED_SCALAR */
267 #else  /* HIP_VERSION */
268     CeedCallBackend(CeedDeviceCopyStrided_Hip(impl->d_array, start, step, length, copy_array));
269 #endif /* HIP_VERSION */
270     CeedCallBackend(CeedVectorRestoreArray(vec_copy, &copy_array));
271     impl->h_array = NULL;
272     CeedCallBackend(CeedDestroy(&ceed));
273   } else if (impl->h_array) {
274     CeedScalar *copy_array;
275 
276     CeedCallBackend(CeedVectorGetArray(vec_copy, CEED_MEM_HOST, &copy_array));
277     CeedCallBackend(CeedHostCopyStrided_Hip(impl->h_array, start, step, length, copy_array));
278     CeedCallBackend(CeedVectorRestoreArray(vec_copy, &copy_array));
279     impl->d_array = NULL;
280   } else {
281     return CeedError(CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND, "CeedVector must have valid data set");
282   }
283   return CEED_ERROR_SUCCESS;
284 }
285 
286 //------------------------------------------------------------------------------
287 // Set host array to value
288 //------------------------------------------------------------------------------
289 static int CeedHostSetValue_Hip(CeedScalar *h_array, CeedSize length, CeedScalar val) {
290   for (CeedSize i = 0; i < length; i++) h_array[i] = val;
291   return CEED_ERROR_SUCCESS;
292 }
293 
294 //------------------------------------------------------------------------------
295 // Set device array to value (impl in .hip file)
296 //------------------------------------------------------------------------------
297 int CeedDeviceSetValue_Hip(CeedScalar *d_array, CeedSize length, CeedScalar val);
298 
299 //------------------------------------------------------------------------------
300 // Set a vector to a value
301 //------------------------------------------------------------------------------
302 static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) {
303   CeedSize        length;
304   CeedVector_Hip *impl;
305 
306   CeedCallBackend(CeedVectorGetData(vec, &impl));
307   CeedCallBackend(CeedVectorGetLength(vec, &length));
308   // Set value for synced device/host array
309   if (!impl->d_array && !impl->h_array) {
310     if (impl->d_array_borrowed) {
311       impl->d_array = impl->d_array_borrowed;
312     } else if (impl->h_array_borrowed) {
313       impl->h_array = impl->h_array_borrowed;
314     } else if (impl->d_array_owned) {
315       impl->d_array = impl->d_array_owned;
316     } else if (impl->h_array_owned) {
317       impl->h_array = impl->h_array_owned;
318     } else {
319       CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL));
320     }
321   }
322   if (impl->d_array) {
323     if (val == 0) {
324       CeedCallHip(CeedVectorReturnCeed(vec), hipMemset(impl->d_array, 0, length * sizeof(CeedScalar)));
325     } else {
326       CeedCallBackend(CeedDeviceSetValue_Hip(impl->d_array, length, val));
327     }
328     impl->h_array = NULL;
329   } else if (impl->h_array) {
330     CeedCallBackend(CeedHostSetValue_Hip(impl->h_array, length, val));
331     impl->d_array = NULL;
332   }
333   return CEED_ERROR_SUCCESS;
334 }
335 
336 //------------------------------------------------------------------------------
337 // Set host array to value strided
338 //------------------------------------------------------------------------------
339 static int CeedHostSetValueStrided_Hip(CeedScalar *h_array, CeedSize start, CeedSize step, CeedSize length, CeedScalar val) {
340   for (CeedSize i = start; i < length; i += step) h_array[i] = val;
341   return CEED_ERROR_SUCCESS;
342 }
343 
344 //------------------------------------------------------------------------------
345 // Set device array to value strided (impl in .hip.cpp file)
346 //------------------------------------------------------------------------------
347 int CeedDeviceSetValueStrided_Hip(CeedScalar *d_array, CeedSize start, CeedSize step, CeedSize length, CeedScalar val);
348 
349 //------------------------------------------------------------------------------
350 // Set a vector to a value strided
351 //------------------------------------------------------------------------------
352 static int CeedVectorSetValueStrided_Hip(CeedVector vec, CeedSize start, CeedSize step, CeedScalar val) {
353   CeedSize        length;
354   CeedVector_Hip *impl;
355 
356   CeedCallBackend(CeedVectorGetData(vec, &impl));
357   CeedCallBackend(CeedVectorGetLength(vec, &length));
358   // Set value for synced device/host array
359   if (impl->d_array) {
360     CeedCallBackend(CeedDeviceSetValueStrided_Hip(impl->d_array, start, step, length, val));
361     impl->h_array = NULL;
362   } else if (impl->h_array) {
363     CeedCallBackend(CeedHostSetValueStrided_Hip(impl->h_array, start, step, length, val));
364     impl->d_array = NULL;
365   } else {
366     return CeedError(CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND, "CeedVector must have valid data set");
367   }
368   return CEED_ERROR_SUCCESS;
369 }
370 
371 //------------------------------------------------------------------------------
372 // Vector Take Array
373 //------------------------------------------------------------------------------
374 static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
375   CeedVector_Hip *impl;
376 
377   CeedCallBackend(CeedVectorGetData(vec, &impl));
378 
379   // Sync array to requested mem_type
380   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
381 
382   // Update pointer
383   switch (mem_type) {
384     case CEED_MEM_HOST:
385       (*array)               = impl->h_array_borrowed;
386       impl->h_array_borrowed = NULL;
387       impl->h_array          = NULL;
388       break;
389     case CEED_MEM_DEVICE:
390       (*array)               = impl->d_array_borrowed;
391       impl->d_array_borrowed = NULL;
392       impl->d_array          = NULL;
393       break;
394   }
395   return CEED_ERROR_SUCCESS;
396 }
397 
398 //------------------------------------------------------------------------------
399 // Core logic for array syncronization for GetArray.
400 //   If a different memory type is most up to date, this will perform a copy
401 //------------------------------------------------------------------------------
402 static int CeedVectorGetArrayCore_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
403   CeedVector_Hip *impl;
404 
405   CeedCallBackend(CeedVectorGetData(vec, &impl));
406 
407   // Sync array to requested mem_type
408   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
409 
410   // Update pointer
411   switch (mem_type) {
412     case CEED_MEM_HOST:
413       *array = impl->h_array;
414       break;
415     case CEED_MEM_DEVICE:
416       *array = impl->d_array;
417       break;
418   }
419   return CEED_ERROR_SUCCESS;
420 }
421 
422 //------------------------------------------------------------------------------
423 // Get read-only access to a vector via the specified mem_type
424 //------------------------------------------------------------------------------
425 static int CeedVectorGetArrayRead_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) {
426   return CeedVectorGetArrayCore_Hip(vec, mem_type, (CeedScalar **)array);
427 }
428 
429 //------------------------------------------------------------------------------
430 // Get read/write access to a vector via the specified mem_type
431 //------------------------------------------------------------------------------
432 static int CeedVectorGetArray_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
433   CeedVector_Hip *impl;
434 
435   CeedCallBackend(CeedVectorGetData(vec, &impl));
436   CeedCallBackend(CeedVectorGetArrayCore_Hip(vec, mem_type, array));
437   CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
438   switch (mem_type) {
439     case CEED_MEM_HOST:
440       impl->h_array = *array;
441       break;
442     case CEED_MEM_DEVICE:
443       impl->d_array = *array;
444       break;
445   }
446   return CEED_ERROR_SUCCESS;
447 }
448 
449 //------------------------------------------------------------------------------
450 // Get write access to a vector via the specified mem_type
451 //------------------------------------------------------------------------------
452 static int CeedVectorGetArrayWrite_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
453   bool            has_array_of_type = true;
454   CeedVector_Hip *impl;
455 
456   CeedCallBackend(CeedVectorGetData(vec, &impl));
457   CeedCallBackend(CeedVectorHasArrayOfType_Hip(vec, mem_type, &has_array_of_type));
458   if (!has_array_of_type) {
459     // Allocate if array is not yet allocated
460     CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL));
461   } else {
462     // Select dirty array
463     switch (mem_type) {
464       case CEED_MEM_HOST:
465         if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed;
466         else impl->h_array = impl->h_array_owned;
467         break;
468       case CEED_MEM_DEVICE:
469         if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed;
470         else impl->d_array = impl->d_array_owned;
471     }
472   }
473   return CeedVectorGetArray_Hip(vec, mem_type, array);
474 }
475 
476 //------------------------------------------------------------------------------
477 // Get the norm of a CeedVector
478 //------------------------------------------------------------------------------
479 static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *norm) {
480   Ceed     ceed;
481   CeedSize length;
482 #if (HIP_VERSION < 60000000)
483   CeedSize num_calls;
484 #endif /* HIP_VERSION */
485   const CeedScalar *d_array;
486   CeedVector_Hip   *impl;
487   hipblasHandle_t   handle;
488 
489   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
490   CeedCallBackend(CeedVectorGetData(vec, &impl));
491   CeedCallBackend(CeedVectorGetLength(vec, &length));
492   CeedCallBackend(CeedGetHipblasHandle_Hip(ceed, &handle));
493 
494 #if (HIP_VERSION < 60000000)
495   // With ROCm 6, we can use the 64-bit integer interface. Prior to that,
496   // we need to check if the vector is too long to handle with int32,
497   // and if so, divide it into subsections for repeated hipBLAS calls.
498   num_calls = length / INT_MAX;
499   if (length % INT_MAX > 0) num_calls += 1;
500 #endif /* HIP_VERSION */
501 
502   // Compute norm
503   CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array));
504   switch (type) {
505     case CEED_NORM_1: {
506       *norm = 0.0;
507 #if defined(CEED_SCALAR_IS_FP32)
508 #if (HIP_VERSION >= 60000000)  // We have ROCm 6, and can use 64-bit integers
509       CeedCallHipblas(ceed, hipblasSasum_64(handle, (int64_t)length, (float *)d_array, 1, (float *)norm));
510 #else  /* HIP_VERSION */
511       float  sub_norm = 0.0;
512       float *d_array_start;
513 
514       for (CeedInt i = 0; i < num_calls; i++) {
515         d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
516         CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
517         CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
518 
519         CeedCallHipblas(ceed, cublasSasum(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
520         *norm += sub_norm;
521       }
522 #endif /* HIP_VERSION */
523 #else  /* CEED_SCALAR */
524 #if (HIP_VERSION >= 60000000)
525       CeedCallHipblas(ceed, hipblasDasum_64(handle, (int64_t)length, (double *)d_array, 1, (double *)norm));
526 #else  /* HIP_VERSION */
527       double  sub_norm = 0.0;
528       double *d_array_start;
529 
530       for (CeedInt i = 0; i < num_calls; i++) {
531         d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
532         CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
533         CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
534 
535         CeedCallHipblas(ceed, hipblasDasum(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
536         *norm += sub_norm;
537       }
538 #endif /* HIP_VERSION */
539 #endif /* CEED_SCALAR */
540       break;
541     }
542     case CEED_NORM_2: {
543 #if defined(CEED_SCALAR_IS_FP32)
544 #if (HIP_VERSION >= 60000000)
545       CeedCallHipblas(ceed, hipblasSnrm2_64(handle, (int64_t)length, (float *)d_array, 1, (float *)norm));
546 #else  /* CUDA_VERSION */
547       float  sub_norm = 0.0, norm_sum = 0.0;
548       float *d_array_start;
549 
550       for (CeedInt i = 0; i < num_calls; i++) {
551         d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
552         CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
553         CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
554 
555         CeedCallHipblas(ceed, hipblasSnrm2(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
556         norm_sum += sub_norm * sub_norm;
557       }
558       *norm = sqrt(norm_sum);
559 #endif /* HIP_VERSION */
560 #else  /* CEED_SCALAR */
561 #if (HIP_VERSION >= 60000000)
562       CeedCallHipblas(ceed, hipblasDnrm2_64(handle, (int64_t)length, (double *)d_array, 1, (double *)norm));
563 #else  /* CUDA_VERSION */
564       double  sub_norm = 0.0, norm_sum = 0.0;
565       double *d_array_start;
566 
567       for (CeedInt i = 0; i < num_calls; i++) {
568         d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
569         CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
570         CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
571 
572         CeedCallHipblas(ceed, hipblasDnrm2(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
573         norm_sum += sub_norm * sub_norm;
574       }
575       *norm = sqrt(norm_sum);
576 #endif /* HIP_VERSION */
577 #endif /* CEED_SCALAR */
578       break;
579     }
580     case CEED_NORM_MAX: {
581 #if defined(CEED_SCALAR_IS_FP32)
582 #if (HIP_VERSION >= 60000000)
583       int64_t    index;
584       CeedScalar norm_no_abs;
585 
586       CeedCallHipblas(ceed, hipblasIsamax_64(handle, (int64_t)length, (float *)d_array, 1, &index));
587       CeedCallHip(ceed, hipMemcpy(&norm_no_abs, impl->d_array + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
588       *norm = fabs(norm_no_abs);
589 #else  /* HIP_VERSION */
590       CeedInt index;
591       float   sub_max = 0.0, current_max = 0.0;
592       float  *d_array_start;
593 
594       for (CeedInt i = 0; i < num_calls; i++) {
595         d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
596         CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
597         CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
598 
599         CeedCallHipblas(ceed, hipblasIsamax(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &index));
600         CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
601         if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
602       }
603       *norm = current_max;
604 #endif /* HIP_VERSION */
605 #else  /* CEED_SCALAR */
606 #if (HIP_VERSION >= 60000000)
607       int64_t    index;
608       CeedScalar norm_no_abs;
609 
610       CeedCallHipblas(ceed, hipblasIdamax_64(handle, (int64_t)length, (double *)d_array, 1, &index));
611       CeedCallHip(ceed, hipMemcpy(&norm_no_abs, impl->d_array + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
612       *norm = fabs(norm_no_abs);
613 #else  /* HIP_VERSION */
614       CeedInt index;
615       double  sub_max = 0.0, current_max = 0.0;
616       double *d_array_start;
617 
618       for (CeedInt i = 0; i < num_calls; i++) {
619         d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
620         CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
621         CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
622 
623         CeedCallHipblas(ceed, hipblasIdamax(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &index));
624         CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
625         if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
626       }
627       *norm = current_max;
628 #endif /* HIP_VERSION */
629 #endif /* CEED_SCALAR */
630       break;
631     }
632   }
633   CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array));
634   CeedCallBackend(CeedDestroy(&ceed));
635   return CEED_ERROR_SUCCESS;
636 }
637 
638 //------------------------------------------------------------------------------
639 // Take reciprocal of a vector on host
640 //------------------------------------------------------------------------------
641 static int CeedHostReciprocal_Hip(CeedScalar *h_array, CeedSize length) {
642   for (CeedSize i = 0; i < length; i++) {
643     if (fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i];
644   }
645   return CEED_ERROR_SUCCESS;
646 }
647 
648 //------------------------------------------------------------------------------
649 // Take reciprocal of a vector on device (impl in .hip.cpp file)
650 //------------------------------------------------------------------------------
651 int CeedDeviceReciprocal_Hip(CeedScalar *d_array, CeedSize length);
652 
653 //------------------------------------------------------------------------------
654 // Take reciprocal of a vector
655 //------------------------------------------------------------------------------
656 static int CeedVectorReciprocal_Hip(CeedVector vec) {
657   CeedSize        length;
658   CeedVector_Hip *impl;
659 
660   CeedCallBackend(CeedVectorGetData(vec, &impl));
661   CeedCallBackend(CeedVectorGetLength(vec, &length));
662   // Set value for synced device/host array
663   if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Hip(impl->d_array, length));
664   if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Hip(impl->h_array, length));
665   return CEED_ERROR_SUCCESS;
666 }
667 
668 //------------------------------------------------------------------------------
669 // Compute x = alpha x on the host
670 //------------------------------------------------------------------------------
671 static int CeedHostScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length) {
672   for (CeedSize i = 0; i < length; i++) x_array[i] *= alpha;
673   return CEED_ERROR_SUCCESS;
674 }
675 
676 //------------------------------------------------------------------------------
677 // Compute x = alpha x on device (impl in .hip.cpp file)
678 //------------------------------------------------------------------------------
679 int CeedDeviceScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length);
680 
681 //------------------------------------------------------------------------------
682 // Compute x = alpha x
683 //------------------------------------------------------------------------------
684 static int CeedVectorScale_Hip(CeedVector x, CeedScalar alpha) {
685   CeedSize        length;
686   CeedVector_Hip *impl;
687 
688   CeedCallBackend(CeedVectorGetData(x, &impl));
689   CeedCallBackend(CeedVectorGetLength(x, &length));
690   // Set value for synced device/host array
691   if (impl->d_array) {
692 #if (HIP_VERSION >= 60000000)
693     hipblasHandle_t handle;
694 
695     CeedCallBackend(CeedGetHipblasHandle_Hip(CeedVectorReturnCeed(x), &handle));
696 #if defined(CEED_SCALAR_IS_FP32)
697     CeedCallHipblas(CeedVectorReturnCeed(x), hipblasSscal_64(handle, (int64_t)length, &alpha, impl->d_array, 1));
698 #else  /* CEED_SCALAR */
699     CeedCallHipblas(CeedVectorReturnCeed(x), hipblasDscal_64(handle, (int64_t)length, &alpha, impl->d_array, 1));
700 #endif /* CEED_SCALAR */
701 #else  /* HIP_VERSION */
702     CeedCallBackend(CeedDeviceScale_Hip(impl->d_array, alpha, length));
703 #endif /* HIP_VERSION */
704     impl->h_array = NULL;
705   }
706   if (impl->h_array) {
707     CeedCallBackend(CeedHostScale_Hip(impl->h_array, alpha, length));
708     impl->d_array = NULL;
709   }
710   return CEED_ERROR_SUCCESS;
711 }
712 
713 //------------------------------------------------------------------------------
714 // Compute y = alpha x + y on the host
715 //------------------------------------------------------------------------------
716 static int CeedHostAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) {
717   for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i];
718   return CEED_ERROR_SUCCESS;
719 }
720 
//------------------------------------------------------------------------------
// Compute y = alpha x + y on device (impl in .hip.cpp file)
//
// Prototype only; the definition lives in the HIP source. In this file it is
// called from CeedVectorAXPY_Hip when HIP_VERSION < 60000000 (newer HIP goes
// through hipBLAS axpy instead), with y_array and x_array being the device
// arrays of y and x. The int result is checked with CeedCallBackend.
//------------------------------------------------------------------------------
int CeedDeviceAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length);
725 
726 //------------------------------------------------------------------------------
727 // Compute y = alpha x + y
728 //------------------------------------------------------------------------------
729 static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) {
730   CeedSize        length;
731   CeedVector_Hip *y_impl, *x_impl;
732 
733   CeedCallBackend(CeedVectorGetData(y, &y_impl));
734   CeedCallBackend(CeedVectorGetData(x, &x_impl));
735   CeedCallBackend(CeedVectorGetLength(y, &length));
736   // Set value for synced device/host array
737   if (y_impl->d_array) {
738     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
739 #if (HIP_VERSION >= 60000000)
740     hipblasHandle_t handle;
741 
742     CeedCallBackend(CeedGetHipblasHandle_Hip(CeedVectorReturnCeed(y), &handle));
743 #if defined(CEED_SCALAR_IS_FP32)
744     CeedCallHipblas(CeedVectorReturnCeed(y), hipblasSaxpy_64(handle, (int64_t)length, &alpha, x_impl->d_array, 1, y_impl->d_array, 1));
745 #else  /* CEED_SCALAR */
746     CeedCallHipblas(CeedVectorReturnCeed(y), hipblasDaxpy_64(handle, (int64_t)length, &alpha, x_impl->d_array, 1, y_impl->d_array, 1));
747 #endif /* CEED_SCALAR */
748 #else  /* HIP_VERSION */
749     CeedCallBackend(CeedDeviceAXPY_Hip(y_impl->d_array, alpha, x_impl->d_array, length));
750 #endif /* HIP_VERSION */
751     y_impl->h_array = NULL;
752   } else if (y_impl->h_array) {
753     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
754     CeedCallBackend(CeedHostAXPY_Hip(y_impl->h_array, alpha, x_impl->h_array, length));
755     y_impl->d_array = NULL;
756   }
757   return CEED_ERROR_SUCCESS;
758 }
759 
760 //------------------------------------------------------------------------------
761 // Compute y = alpha x + beta y on the host
762 //------------------------------------------------------------------------------
763 static int CeedHostAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length) {
764   for (CeedSize i = 0; i < length; i++) y_array[i] = alpha * x_array[i] + beta * y_array[i];
765   return CEED_ERROR_SUCCESS;
766 }
767 
//------------------------------------------------------------------------------
// Compute y = alpha x + beta y on device (impl in .hip.cpp file)
//
// Prototype only; the definition lives in the HIP source. In this file it is
// called from CeedVectorAXPBY_Hip (there is no hipBLAS path for AXPBY here),
// with y_array and x_array being the device arrays of y and x. The int result
// is checked with CeedCallBackend.
//------------------------------------------------------------------------------
int CeedDeviceAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length);
772 
773 //------------------------------------------------------------------------------
774 // Compute y = alpha x + beta y
775 //------------------------------------------------------------------------------
776 static int CeedVectorAXPBY_Hip(CeedVector y, CeedScalar alpha, CeedScalar beta, CeedVector x) {
777   CeedSize        length;
778   CeedVector_Hip *y_impl, *x_impl;
779 
780   CeedCallBackend(CeedVectorGetData(y, &y_impl));
781   CeedCallBackend(CeedVectorGetData(x, &x_impl));
782   CeedCallBackend(CeedVectorGetLength(y, &length));
783   // Set value for synced device/host array
784   if (y_impl->d_array) {
785     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
786     CeedCallBackend(CeedDeviceAXPBY_Hip(y_impl->d_array, alpha, beta, x_impl->d_array, length));
787   }
788   if (y_impl->h_array) {
789     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
790     CeedCallBackend(CeedHostAXPBY_Hip(y_impl->h_array, alpha, beta, x_impl->h_array, length));
791   }
792   return CEED_ERROR_SUCCESS;
793 }
794 
795 //------------------------------------------------------------------------------
796 // Compute the pointwise multiplication w = x .* y on the host
797 //------------------------------------------------------------------------------
798 static int CeedHostPointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) {
799   for (CeedSize i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i];
800   return CEED_ERROR_SUCCESS;
801 }
802 
//------------------------------------------------------------------------------
// Compute the pointwise multiplication w = x .* y on device (impl in .hip.cpp file)
//
// Prototype only; the definition lives in the HIP source. In this file it is
// called from CeedVectorPointwiseMult_Hip with the device arrays of w, x and y.
// The int result is checked with CeedCallBackend.
//------------------------------------------------------------------------------
int CeedDevicePointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length);
807 
808 //------------------------------------------------------------------------------
809 // Compute the pointwise multiplication w = x .* y
810 //------------------------------------------------------------------------------
811 static int CeedVectorPointwiseMult_Hip(CeedVector w, CeedVector x, CeedVector y) {
812   CeedSize        length;
813   CeedVector_Hip *w_impl, *x_impl, *y_impl;
814 
815   CeedCallBackend(CeedVectorGetData(w, &w_impl));
816   CeedCallBackend(CeedVectorGetData(x, &x_impl));
817   CeedCallBackend(CeedVectorGetData(y, &y_impl));
818   CeedCallBackend(CeedVectorGetLength(w, &length));
819 
820   // Set value for synced device/host array
821   if (!w_impl->d_array && !w_impl->h_array) {
822     CeedCallBackend(CeedVectorSetValue(w, 0.0));
823   }
824   if (w_impl->d_array) {
825     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
826     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE));
827     CeedCallBackend(CeedDevicePointwiseMult_Hip(w_impl->d_array, x_impl->d_array, y_impl->d_array, length));
828   }
829   if (w_impl->h_array) {
830     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
831     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST));
832     CeedCallBackend(CeedHostPointwiseMult_Hip(w_impl->h_array, x_impl->h_array, y_impl->h_array, length));
833   }
834   return CEED_ERROR_SUCCESS;
835 }
836 
837 //------------------------------------------------------------------------------
838 // Destroy the vector
839 //------------------------------------------------------------------------------
840 static int CeedVectorDestroy_Hip(const CeedVector vec) {
841   CeedVector_Hip *impl;
842 
843   CeedCallBackend(CeedVectorGetData(vec, &impl));
844   CeedCallHip(CeedVectorReturnCeed(vec), hipFree(impl->d_array_owned));
845   CeedCallBackend(CeedFree(&impl->h_array_owned));
846   CeedCallBackend(CeedFree(&impl));
847   return CEED_ERROR_SUCCESS;
848 }
849 
850 //------------------------------------------------------------------------------
851 // Create a vector of the specified length (does not allocate memory)
852 //------------------------------------------------------------------------------
853 int CeedVectorCreate_Hip(CeedSize n, CeedVector vec) {
854   CeedVector_Hip *impl;
855   Ceed            ceed;
856 
857   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
858   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Hip));
859   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Hip));
860   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Hip));
861   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Hip));
862   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "CopyStrided", CeedVectorCopyStrided_Hip));
863   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", CeedVectorSetValue_Hip));
864   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValueStrided", CeedVectorSetValueStrided_Hip));
865   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Hip));
866   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Hip));
867   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Hip));
868   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Hip));
869   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Norm", CeedVectorNorm_Hip));
870   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Hip));
871   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Scale", CeedVectorScale_Hip));
872   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPY", CeedVectorAXPY_Hip));
873   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPBY", CeedVectorAXPBY_Hip));
874   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Hip));
875   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Hip));
876   CeedCallBackend(CeedDestroy(&ceed));
877   CeedCallBackend(CeedCalloc(1, &impl));
878   CeedCallBackend(CeedVectorSetData(vec, impl));
879   return CEED_ERROR_SUCCESS;
880 }
881 
882 //------------------------------------------------------------------------------
883