xref: /libCEED/backends/hip-ref/ceed-hip-ref-vector.c (revision f52f9f6cf82557c448b4d7d27acc9d71dfcc9d91)
1 // Copyright (c) 2017-2025, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <math.h>
11 #include <stdbool.h>
12 #include <string.h>
13 #include <hip/hip_runtime.h>
14 
15 #include "../hip/ceed-hip-common.h"
16 #include "ceed-hip-ref.h"
17 
18 //------------------------------------------------------------------------------
19 // Check if host/device sync is needed
20 //------------------------------------------------------------------------------
21 static inline int CeedVectorNeedSync_Hip(const CeedVector vec, CeedMemType mem_type, bool *need_sync) {
22   CeedVector_Hip *impl;
23   bool            has_valid_array = false;
24 
25   CeedCallBackend(CeedVectorGetData(vec, &impl));
26   CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array));
27   switch (mem_type) {
28     case CEED_MEM_HOST:
29       *need_sync = has_valid_array && !impl->h_array;
30       break;
31     case CEED_MEM_DEVICE:
32       *need_sync = has_valid_array && !impl->d_array;
33       break;
34   }
35   return CEED_ERROR_SUCCESS;
36 }
37 
38 //------------------------------------------------------------------------------
39 // Sync host to device
40 //------------------------------------------------------------------------------
41 static inline int CeedVectorSyncH2D_Hip(const CeedVector vec) {
42   CeedSize        length;
43   size_t          bytes;
44   CeedVector_Hip *impl;
45 
46   CeedCallBackend(CeedVectorGetData(vec, &impl));
47 
48   CeedCheck(impl->h_array, CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND, "No valid host data to sync to device");
49 
50   CeedCallBackend(CeedVectorGetLength(vec, &length));
51   bytes = length * sizeof(CeedScalar);
52   if (impl->d_array_borrowed) {
53     impl->d_array = impl->d_array_borrowed;
54   } else if (impl->d_array_owned) {
55     impl->d_array = impl->d_array_owned;
56   } else {
57     CeedCallHip(CeedVectorReturnCeed(vec), hipMalloc((void **)&impl->d_array_owned, bytes));
58     impl->d_array = impl->d_array_owned;
59   }
60   CeedCallHip(CeedVectorReturnCeed(vec), hipMemcpy(impl->d_array, impl->h_array, bytes, hipMemcpyHostToDevice));
61   return CEED_ERROR_SUCCESS;
62 }
63 
64 //------------------------------------------------------------------------------
65 // Sync device to host
66 //------------------------------------------------------------------------------
67 static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) {
68   CeedSize        length;
69   size_t          bytes;
70   CeedVector_Hip *impl;
71 
72   CeedCallBackend(CeedVectorGetData(vec, &impl));
73 
74   CeedCheck(impl->d_array, CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND, "No valid device data to sync to host");
75 
76   if (impl->h_array_borrowed) {
77     impl->h_array = impl->h_array_borrowed;
78   } else if (impl->h_array_owned) {
79     impl->h_array = impl->h_array_owned;
80   } else {
81     CeedSize length;
82 
83     CeedCallBackend(CeedVectorGetLength(vec, &length));
84     CeedCallBackend(CeedCalloc(length, &impl->h_array_owned));
85     impl->h_array = impl->h_array_owned;
86   }
87 
88   CeedCallBackend(CeedVectorGetLength(vec, &length));
89   bytes = length * sizeof(CeedScalar);
90   CeedCallHip(CeedVectorReturnCeed(vec), hipMemcpy(impl->h_array, impl->d_array, bytes, hipMemcpyDeviceToHost));
91   return CEED_ERROR_SUCCESS;
92 }
93 
94 //------------------------------------------------------------------------------
95 // Sync arrays
96 //------------------------------------------------------------------------------
97 static int CeedVectorSyncArray_Hip(const CeedVector vec, CeedMemType mem_type) {
98   bool need_sync = false;
99 
100   // Check whether device/host sync is needed
101   CeedCallBackend(CeedVectorNeedSync_Hip(vec, mem_type, &need_sync));
102   if (!need_sync) return CEED_ERROR_SUCCESS;
103 
104   switch (mem_type) {
105     case CEED_MEM_HOST:
106       return CeedVectorSyncD2H_Hip(vec);
107     case CEED_MEM_DEVICE:
108       return CeedVectorSyncH2D_Hip(vec);
109   }
110   return CEED_ERROR_UNSUPPORTED;
111 }
112 
113 //------------------------------------------------------------------------------
114 // Set all pointers as invalid
115 //------------------------------------------------------------------------------
116 static inline int CeedVectorSetAllInvalid_Hip(const CeedVector vec) {
117   CeedVector_Hip *impl;
118 
119   CeedCallBackend(CeedVectorGetData(vec, &impl));
120   impl->h_array = NULL;
121   impl->d_array = NULL;
122   return CEED_ERROR_SUCCESS;
123 }
124 
125 //------------------------------------------------------------------------------
126 // Check if CeedVector has any valid pointer
127 //------------------------------------------------------------------------------
128 static inline int CeedVectorHasValidArray_Hip(const CeedVector vec, bool *has_valid_array) {
129   CeedVector_Hip *impl;
130 
131   CeedCallBackend(CeedVectorGetData(vec, &impl));
132   *has_valid_array = impl->h_array || impl->d_array;
133   return CEED_ERROR_SUCCESS;
134 }
135 
136 //------------------------------------------------------------------------------
137 // Check if has array of given type
138 //------------------------------------------------------------------------------
139 static inline int CeedVectorHasArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) {
140   CeedVector_Hip *impl;
141 
142   CeedCallBackend(CeedVectorGetData(vec, &impl));
143   switch (mem_type) {
144     case CEED_MEM_HOST:
145       *has_array_of_type = impl->h_array_borrowed || impl->h_array_owned;
146       break;
147     case CEED_MEM_DEVICE:
148       *has_array_of_type = impl->d_array_borrowed || impl->d_array_owned;
149       break;
150   }
151   return CEED_ERROR_SUCCESS;
152 }
153 
154 //------------------------------------------------------------------------------
155 // Check if has borrowed array of given type
156 //------------------------------------------------------------------------------
157 static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) {
158   CeedVector_Hip *impl;
159 
160   CeedCallBackend(CeedVectorGetData(vec, &impl));
161   switch (mem_type) {
162     case CEED_MEM_HOST:
163       *has_borrowed_array_of_type = impl->h_array_borrowed;
164       break;
165     case CEED_MEM_DEVICE:
166       *has_borrowed_array_of_type = impl->d_array_borrowed;
167       break;
168   }
169   return CEED_ERROR_SUCCESS;
170 }
171 
172 //------------------------------------------------------------------------------
173 // Set array from host
174 //------------------------------------------------------------------------------
175 static int CeedVectorSetArrayHost_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
176   CeedSize        length;
177   CeedVector_Hip *impl;
178 
179   CeedCallBackend(CeedVectorGetData(vec, &impl));
180   CeedCallBackend(CeedVectorGetLength(vec, &length));
181 
182   CeedCallBackend(CeedSetHostCeedScalarArray(array, copy_mode, length, (const CeedScalar **)&impl->h_array_owned,
183                                              (const CeedScalar **)&impl->h_array_borrowed, (const CeedScalar **)&impl->h_array));
184   return CEED_ERROR_SUCCESS;
185 }
186 
187 //------------------------------------------------------------------------------
188 // Set array from device
189 //------------------------------------------------------------------------------
190 static int CeedVectorSetArrayDevice_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
191   CeedSize        length;
192   Ceed            ceed;
193   CeedVector_Hip *impl;
194 
195   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
196   CeedCallBackend(CeedVectorGetData(vec, &impl));
197   CeedCallBackend(CeedVectorGetLength(vec, &length));
198 
199   CeedCallBackend(CeedSetDeviceCeedScalarArray_Hip(ceed, array, copy_mode, length, (const CeedScalar **)&impl->d_array_owned,
200                                                    (const CeedScalar **)&impl->d_array_borrowed, (const CeedScalar **)&impl->d_array));
201   CeedCallBackend(CeedDestroy(&ceed));
202   return CEED_ERROR_SUCCESS;
203 }
204 
205 //------------------------------------------------------------------------------
206 // Set the array used by a vector,
207 //   freeing any previously allocated array if applicable
208 //------------------------------------------------------------------------------
209 static int CeedVectorSetArray_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) {
210   CeedVector_Hip *impl;
211 
212   CeedCallBackend(CeedVectorGetData(vec, &impl));
213   CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
214   switch (mem_type) {
215     case CEED_MEM_HOST:
216       return CeedVectorSetArrayHost_Hip(vec, copy_mode, array);
217     case CEED_MEM_DEVICE:
218       return CeedVectorSetArrayDevice_Hip(vec, copy_mode, array);
219   }
220   return CEED_ERROR_UNSUPPORTED;
221 }
222 
223 //------------------------------------------------------------------------------
224 // Copy host array to value strided
225 //------------------------------------------------------------------------------
226 static int CeedHostCopyStrided_Hip(CeedScalar *h_array, CeedSize start, CeedSize step, CeedSize length, CeedScalar *h_copy_array) {
227   for (CeedSize i = start; i < length; i += step) h_copy_array[i] = h_array[i];
228   return CEED_ERROR_SUCCESS;
229 }
230 
231 //------------------------------------------------------------------------------
232 // Copy device array to value strided (impl in .hip.cpp file)
233 //------------------------------------------------------------------------------
234 int CeedDeviceCopyStrided_Hip(CeedScalar *d_array, CeedSize start, CeedSize step, CeedSize length, CeedScalar *d_copy_array);
235 
236 //------------------------------------------------------------------------------
237 // Copy a vector to a value strided
238 //------------------------------------------------------------------------------
239 static int CeedVectorCopyStrided_Hip(CeedVector vec, CeedSize start, CeedSize step, CeedVector vec_copy) {
240   CeedSize        length;
241   CeedVector_Hip *impl;
242 
243   CeedCallBackend(CeedVectorGetData(vec, &impl));
244   {
245     CeedSize length_vec, length_copy;
246 
247     CeedCallBackend(CeedVectorGetLength(vec, &length_vec));
248     CeedCallBackend(CeedVectorGetLength(vec_copy, &length_copy));
249     length = length_vec < length_copy ? length_vec : length_copy;
250   }
251   // Set value for synced device/host array
252   if (impl->d_array) {
253     CeedScalar *copy_array;
254 
255     CeedCallBackend(CeedVectorGetArray(vec_copy, CEED_MEM_DEVICE, &copy_array));
256 #if (HIP_VERSION >= 60000000)
257     hipblasHandle_t handle;
258     Ceed            ceed;
259 
260     CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
261     CeedCallBackend(CeedGetHipblasHandle_Hip(ceed, &handle));
262 #if defined(CEED_SCALAR_IS_FP32)
263     CeedCallHipblas(ceed, hipblasScopy_64(handle, (int64_t)length, impl->d_array + start, (int64_t)step, copy_array + start, (int64_t)step));
264 #else  /* CEED_SCALAR */
265     CeedCallHipblas(ceed, hipblasDcopy_64(handle, (int64_t)length, impl->d_array + start, (int64_t)step, copy_array + start, (int64_t)step));
266 #endif /* CEED_SCALAR */
267 #else  /* HIP_VERSION */
268     CeedCallBackend(CeedDeviceCopyStrided_Hip(impl->d_array, start, step, length, copy_array));
269 #endif /* HIP_VERSION */
270     CeedCallBackend(CeedVectorRestoreArray(vec_copy, &copy_array));
271     impl->h_array = NULL;
272     CeedCallBackend(CeedDestroy(&ceed));
273   } else if (impl->h_array) {
274     CeedScalar *copy_array;
275 
276     CeedCallBackend(CeedVectorGetArray(vec_copy, CEED_MEM_HOST, &copy_array));
277     CeedCallBackend(CeedHostCopyStrided_Hip(impl->h_array, start, step, length, copy_array));
278     CeedCallBackend(CeedVectorRestoreArray(vec_copy, &copy_array));
279     impl->d_array = NULL;
280   } else {
281     return CeedError(CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND, "CeedVector must have valid data set");
282   }
283   return CEED_ERROR_SUCCESS;
284 }
285 
286 //------------------------------------------------------------------------------
287 // Set host array to value
288 //------------------------------------------------------------------------------
289 static int CeedHostSetValue_Hip(CeedScalar *h_array, CeedSize length, CeedScalar val) {
290   for (CeedSize i = 0; i < length; i++) h_array[i] = val;
291   return CEED_ERROR_SUCCESS;
292 }
293 
294 //------------------------------------------------------------------------------
295 // Set device array to value (impl in .hip file)
296 //------------------------------------------------------------------------------
297 int CeedDeviceSetValue_Hip(CeedScalar *d_array, CeedSize length, CeedScalar val);
298 
299 //------------------------------------------------------------------------------
300 // Set a vector to a value
301 //------------------------------------------------------------------------------
302 static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) {
303   CeedSize        length;
304   CeedVector_Hip *impl;
305 
306   CeedCallBackend(CeedVectorGetData(vec, &impl));
307   CeedCallBackend(CeedVectorGetLength(vec, &length));
308   // Set value for synced device/host array
309   if (!impl->d_array && !impl->h_array) {
310     if (impl->d_array_borrowed) {
311       impl->d_array = impl->d_array_borrowed;
312     } else if (impl->h_array_borrowed) {
313       impl->h_array = impl->h_array_borrowed;
314     } else if (impl->d_array_owned) {
315       impl->d_array = impl->d_array_owned;
316     } else if (impl->h_array_owned) {
317       impl->h_array = impl->h_array_owned;
318     } else {
319       CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL));
320     }
321   }
322   if (impl->d_array) {
323     if (val == 0) {
324       CeedCallHip(CeedVectorReturnCeed(vec), hipMemset(impl->d_array, 0, length * sizeof(CeedScalar)));
325     } else {
326       CeedCallBackend(CeedDeviceSetValue_Hip(impl->d_array, length, val));
327     }
328     impl->h_array = NULL;
329   } else if (impl->h_array) {
330     CeedCallBackend(CeedHostSetValue_Hip(impl->h_array, length, val));
331     impl->d_array = NULL;
332   }
333   return CEED_ERROR_SUCCESS;
334 }
335 
336 //------------------------------------------------------------------------------
337 // Set host array to value strided
338 //------------------------------------------------------------------------------
339 static int CeedHostSetValueStrided_Hip(CeedScalar *h_array, CeedSize start, CeedSize stop, CeedSize step, CeedScalar val) {
340   for (CeedSize i = start; i < stop; i += step) h_array[i] = val;
341   return CEED_ERROR_SUCCESS;
342 }
343 
344 //------------------------------------------------------------------------------
345 // Set device array to value strided (impl in .hip.cpp file)
346 //------------------------------------------------------------------------------
347 int CeedDeviceSetValueStrided_Hip(CeedScalar *d_array, CeedSize start, CeedSize stop, CeedSize step, CeedScalar val);
348 
349 //------------------------------------------------------------------------------
350 // Set a vector to a value strided
351 //------------------------------------------------------------------------------
352 static int CeedVectorSetValueStrided_Hip(CeedVector vec, CeedSize start, CeedSize stop, CeedSize step, CeedScalar val) {
353   CeedSize        length;
354   CeedVector_Hip *impl;
355 
356   CeedCallBackend(CeedVectorGetData(vec, &impl));
357   CeedCallBackend(CeedVectorGetLength(vec, &length));
358   // Set value for synced device/host array
359   if (stop == -1) stop = length;
360   if (impl->d_array) {
361     CeedCallBackend(CeedDeviceSetValueStrided_Hip(impl->d_array, start, stop, step, val));
362     impl->h_array = NULL;
363   } else if (impl->h_array) {
364     CeedCallBackend(CeedHostSetValueStrided_Hip(impl->h_array, start, stop, step, val));
365     impl->d_array = NULL;
366   } else {
367     return CeedError(CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND, "CeedVector must have valid data set");
368   }
369   return CEED_ERROR_SUCCESS;
370 }
371 
372 //------------------------------------------------------------------------------
373 // Vector Take Array
374 //------------------------------------------------------------------------------
375 static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
376   CeedVector_Hip *impl;
377 
378   CeedCallBackend(CeedVectorGetData(vec, &impl));
379 
380   // Sync array to requested mem_type
381   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
382 
383   // Update pointer
384   switch (mem_type) {
385     case CEED_MEM_HOST:
386       (*array)               = impl->h_array_borrowed;
387       impl->h_array_borrowed = NULL;
388       impl->h_array          = NULL;
389       break;
390     case CEED_MEM_DEVICE:
391       (*array)               = impl->d_array_borrowed;
392       impl->d_array_borrowed = NULL;
393       impl->d_array          = NULL;
394       break;
395   }
396   return CEED_ERROR_SUCCESS;
397 }
398 
399 //------------------------------------------------------------------------------
400 // Core logic for array syncronization for GetArray.
401 //   If a different memory type is most up to date, this will perform a copy
402 //------------------------------------------------------------------------------
403 static int CeedVectorGetArrayCore_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
404   CeedVector_Hip *impl;
405 
406   CeedCallBackend(CeedVectorGetData(vec, &impl));
407 
408   // Sync array to requested mem_type
409   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
410 
411   // Update pointer
412   switch (mem_type) {
413     case CEED_MEM_HOST:
414       *array = impl->h_array;
415       break;
416     case CEED_MEM_DEVICE:
417       *array = impl->d_array;
418       break;
419   }
420   return CEED_ERROR_SUCCESS;
421 }
422 
423 //------------------------------------------------------------------------------
424 // Get read-only access to a vector via the specified mem_type
425 //------------------------------------------------------------------------------
426 static int CeedVectorGetArrayRead_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) {
427   return CeedVectorGetArrayCore_Hip(vec, mem_type, (CeedScalar **)array);
428 }
429 
430 //------------------------------------------------------------------------------
431 // Get read/write access to a vector via the specified mem_type
432 //------------------------------------------------------------------------------
433 static int CeedVectorGetArray_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
434   CeedVector_Hip *impl;
435 
436   CeedCallBackend(CeedVectorGetData(vec, &impl));
437   CeedCallBackend(CeedVectorGetArrayCore_Hip(vec, mem_type, array));
438   CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
439   switch (mem_type) {
440     case CEED_MEM_HOST:
441       impl->h_array = *array;
442       break;
443     case CEED_MEM_DEVICE:
444       impl->d_array = *array;
445       break;
446   }
447   return CEED_ERROR_SUCCESS;
448 }
449 
450 //------------------------------------------------------------------------------
451 // Get write access to a vector via the specified mem_type
452 //------------------------------------------------------------------------------
453 static int CeedVectorGetArrayWrite_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
454   bool            has_array_of_type = true;
455   CeedVector_Hip *impl;
456 
457   CeedCallBackend(CeedVectorGetData(vec, &impl));
458   CeedCallBackend(CeedVectorHasArrayOfType_Hip(vec, mem_type, &has_array_of_type));
459   if (!has_array_of_type) {
460     // Allocate if array is not yet allocated
461     CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL));
462   } else {
463     // Select dirty array
464     switch (mem_type) {
465       case CEED_MEM_HOST:
466         if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed;
467         else impl->h_array = impl->h_array_owned;
468         break;
469       case CEED_MEM_DEVICE:
470         if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed;
471         else impl->d_array = impl->d_array_owned;
472     }
473   }
474   return CeedVectorGetArray_Hip(vec, mem_type, array);
475 }
476 
477 //------------------------------------------------------------------------------
478 // Get the norm of a CeedVector
479 //------------------------------------------------------------------------------
480 static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *norm) {
481   Ceed     ceed;
482   CeedSize length;
483 #if (HIP_VERSION < 60000000)
484   CeedSize num_calls;
485 #endif /* HIP_VERSION */
486   const CeedScalar *d_array;
487   CeedVector_Hip   *impl;
488   hipblasHandle_t   handle;
489 
490   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
491   CeedCallBackend(CeedVectorGetData(vec, &impl));
492   CeedCallBackend(CeedVectorGetLength(vec, &length));
493   CeedCallBackend(CeedGetHipblasHandle_Hip(ceed, &handle));
494 
495 #if (HIP_VERSION < 60000000)
496   // With ROCm 6, we can use the 64-bit integer interface. Prior to that,
497   // we need to check if the vector is too long to handle with int32,
498   // and if so, divide it into subsections for repeated hipBLAS calls.
499   num_calls = length / INT_MAX;
500   if (length % INT_MAX > 0) num_calls += 1;
501 #endif /* HIP_VERSION */
502 
503   // Compute norm
504   CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array));
505   switch (type) {
506     case CEED_NORM_1: {
507       *norm = 0.0;
508 #if defined(CEED_SCALAR_IS_FP32)
509 #if (HIP_VERSION >= 60000000)  // We have ROCm 6, and can use 64-bit integers
510       CeedCallHipblas(ceed, hipblasSasum_64(handle, (int64_t)length, (float *)d_array, 1, (float *)norm));
511 #else  /* HIP_VERSION */
512       float  sub_norm = 0.0;
513       float *d_array_start;
514 
515       for (CeedInt i = 0; i < num_calls; i++) {
516         d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
517         CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
518         CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
519 
520         CeedCallHipblas(ceed, cublasSasum(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
521         *norm += sub_norm;
522       }
523 #endif /* HIP_VERSION */
524 #else  /* CEED_SCALAR */
525 #if (HIP_VERSION >= 60000000)
526       CeedCallHipblas(ceed, hipblasDasum_64(handle, (int64_t)length, (double *)d_array, 1, (double *)norm));
527 #else  /* HIP_VERSION */
528       double  sub_norm = 0.0;
529       double *d_array_start;
530 
531       for (CeedInt i = 0; i < num_calls; i++) {
532         d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
533         CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
534         CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
535 
536         CeedCallHipblas(ceed, hipblasDasum(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
537         *norm += sub_norm;
538       }
539 #endif /* HIP_VERSION */
540 #endif /* CEED_SCALAR */
541       break;
542     }
543     case CEED_NORM_2: {
544 #if defined(CEED_SCALAR_IS_FP32)
545 #if (HIP_VERSION >= 60000000)
546       CeedCallHipblas(ceed, hipblasSnrm2_64(handle, (int64_t)length, (float *)d_array, 1, (float *)norm));
547 #else  /* CUDA_VERSION */
548       float  sub_norm = 0.0, norm_sum = 0.0;
549       float *d_array_start;
550 
551       for (CeedInt i = 0; i < num_calls; i++) {
552         d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
553         CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
554         CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
555 
556         CeedCallHipblas(ceed, hipblasSnrm2(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
557         norm_sum += sub_norm * sub_norm;
558       }
559       *norm = sqrt(norm_sum);
560 #endif /* HIP_VERSION */
561 #else  /* CEED_SCALAR */
562 #if (HIP_VERSION >= 60000000)
563       CeedCallHipblas(ceed, hipblasDnrm2_64(handle, (int64_t)length, (double *)d_array, 1, (double *)norm));
564 #else  /* CUDA_VERSION */
565       double  sub_norm = 0.0, norm_sum = 0.0;
566       double *d_array_start;
567 
568       for (CeedInt i = 0; i < num_calls; i++) {
569         d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
570         CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
571         CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
572 
573         CeedCallHipblas(ceed, hipblasDnrm2(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
574         norm_sum += sub_norm * sub_norm;
575       }
576       *norm = sqrt(norm_sum);
577 #endif /* HIP_VERSION */
578 #endif /* CEED_SCALAR */
579       break;
580     }
581     case CEED_NORM_MAX: {
582 #if defined(CEED_SCALAR_IS_FP32)
583 #if (HIP_VERSION >= 60000000)
584       int64_t    index;
585       CeedScalar norm_no_abs;
586 
587       CeedCallHipblas(ceed, hipblasIsamax_64(handle, (int64_t)length, (float *)d_array, 1, &index));
588       CeedCallHip(ceed, hipMemcpy(&norm_no_abs, impl->d_array + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
589       *norm = fabs(norm_no_abs);
590 #else  /* HIP_VERSION */
591       CeedInt index;
592       float   sub_max = 0.0, current_max = 0.0;
593       float  *d_array_start;
594 
595       for (CeedInt i = 0; i < num_calls; i++) {
596         d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
597         CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
598         CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
599 
600         CeedCallHipblas(ceed, hipblasIsamax(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &index));
601         CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
602         if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
603       }
604       *norm = current_max;
605 #endif /* HIP_VERSION */
606 #else  /* CEED_SCALAR */
607 #if (HIP_VERSION >= 60000000)
608       int64_t    index;
609       CeedScalar norm_no_abs;
610 
611       CeedCallHipblas(ceed, hipblasIdamax_64(handle, (int64_t)length, (double *)d_array, 1, &index));
612       CeedCallHip(ceed, hipMemcpy(&norm_no_abs, impl->d_array + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
613       *norm = fabs(norm_no_abs);
614 #else  /* HIP_VERSION */
615       CeedInt index;
616       double  sub_max = 0.0, current_max = 0.0;
617       double *d_array_start;
618 
619       for (CeedInt i = 0; i < num_calls; i++) {
620         d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
621         CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
622         CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
623 
624         CeedCallHipblas(ceed, hipblasIdamax(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &index));
625         CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
626         if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
627       }
628       *norm = current_max;
629 #endif /* HIP_VERSION */
630 #endif /* CEED_SCALAR */
631       break;
632     }
633   }
634   CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array));
635   CeedCallBackend(CeedDestroy(&ceed));
636   return CEED_ERROR_SUCCESS;
637 }
638 
639 //------------------------------------------------------------------------------
640 // Take reciprocal of a vector on host
641 //------------------------------------------------------------------------------
642 static int CeedHostReciprocal_Hip(CeedScalar *h_array, CeedSize length) {
643   for (CeedSize i = 0; i < length; i++) {
644     if (fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i];
645   }
646   return CEED_ERROR_SUCCESS;
647 }
648 
649 //------------------------------------------------------------------------------
650 // Take reciprocal of a vector on device (impl in .hip.cpp file)
651 //------------------------------------------------------------------------------
652 int CeedDeviceReciprocal_Hip(CeedScalar *d_array, CeedSize length);
653 
654 //------------------------------------------------------------------------------
655 // Take reciprocal of a vector
656 //------------------------------------------------------------------------------
657 static int CeedVectorReciprocal_Hip(CeedVector vec) {
658   CeedSize        length;
659   CeedVector_Hip *impl;
660 
661   CeedCallBackend(CeedVectorGetData(vec, &impl));
662   CeedCallBackend(CeedVectorGetLength(vec, &length));
663   // Set value for synced device/host array
664   if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Hip(impl->d_array, length));
665   if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Hip(impl->h_array, length));
666   return CEED_ERROR_SUCCESS;
667 }
668 
669 //------------------------------------------------------------------------------
670 // Compute x = alpha x on the host
671 //------------------------------------------------------------------------------
672 static int CeedHostScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length) {
673   for (CeedSize i = 0; i < length; i++) x_array[i] *= alpha;
674   return CEED_ERROR_SUCCESS;
675 }
676 
677 //------------------------------------------------------------------------------
678 // Compute x = alpha x on device (impl in .hip.cpp file)
679 //------------------------------------------------------------------------------
680 int CeedDeviceScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length);
681 
682 //------------------------------------------------------------------------------
683 // Compute x = alpha x
684 //------------------------------------------------------------------------------
685 static int CeedVectorScale_Hip(CeedVector x, CeedScalar alpha) {
686   CeedSize        length;
687   CeedVector_Hip *impl;
688 
689   CeedCallBackend(CeedVectorGetData(x, &impl));
690   CeedCallBackend(CeedVectorGetLength(x, &length));
691   // Set value for synced device/host array
692   if (impl->d_array) {
693 #if (HIP_VERSION >= 60000000)
694     hipblasHandle_t handle;
695 
696     CeedCallBackend(CeedGetHipblasHandle_Hip(CeedVectorReturnCeed(x), &handle));
697 #if defined(CEED_SCALAR_IS_FP32)
698     CeedCallHipblas(CeedVectorReturnCeed(x), hipblasSscal_64(handle, (int64_t)length, &alpha, impl->d_array, 1));
699 #else  /* CEED_SCALAR */
700     CeedCallHipblas(CeedVectorReturnCeed(x), hipblasDscal_64(handle, (int64_t)length, &alpha, impl->d_array, 1));
701 #endif /* CEED_SCALAR */
702 #else  /* HIP_VERSION */
703     CeedCallBackend(CeedDeviceScale_Hip(impl->d_array, alpha, length));
704 #endif /* HIP_VERSION */
705     impl->h_array = NULL;
706   }
707   if (impl->h_array) {
708     CeedCallBackend(CeedHostScale_Hip(impl->h_array, alpha, length));
709     impl->d_array = NULL;
710   }
711   return CEED_ERROR_SUCCESS;
712 }
713 
714 //------------------------------------------------------------------------------
715 // Compute y = alpha x + y on the host
716 //------------------------------------------------------------------------------
717 static int CeedHostAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) {
718   for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i];
719   return CEED_ERROR_SUCCESS;
720 }
721 
722 //------------------------------------------------------------------------------
723 // Compute y = alpha x + y on device (impl in .hip.cpp file)
724 //------------------------------------------------------------------------------
725 int CeedDeviceAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length);
726 
727 //------------------------------------------------------------------------------
728 // Compute y = alpha x + y
729 //------------------------------------------------------------------------------
730 static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) {
731   CeedSize        length;
732   CeedVector_Hip *y_impl, *x_impl;
733 
734   CeedCallBackend(CeedVectorGetData(y, &y_impl));
735   CeedCallBackend(CeedVectorGetData(x, &x_impl));
736   CeedCallBackend(CeedVectorGetLength(y, &length));
737   // Set value for synced device/host array
738   if (y_impl->d_array) {
739     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
740 #if (HIP_VERSION >= 60000000)
741     hipblasHandle_t handle;
742 
743     CeedCallBackend(CeedGetHipblasHandle_Hip(CeedVectorReturnCeed(y), &handle));
744 #if defined(CEED_SCALAR_IS_FP32)
745     CeedCallHipblas(CeedVectorReturnCeed(y), hipblasSaxpy_64(handle, (int64_t)length, &alpha, x_impl->d_array, 1, y_impl->d_array, 1));
746 #else  /* CEED_SCALAR */
747     CeedCallHipblas(CeedVectorReturnCeed(y), hipblasDaxpy_64(handle, (int64_t)length, &alpha, x_impl->d_array, 1, y_impl->d_array, 1));
748 #endif /* CEED_SCALAR */
749 #else  /* HIP_VERSION */
750     CeedCallBackend(CeedDeviceAXPY_Hip(y_impl->d_array, alpha, x_impl->d_array, length));
751 #endif /* HIP_VERSION */
752     y_impl->h_array = NULL;
753   } else if (y_impl->h_array) {
754     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
755     CeedCallBackend(CeedHostAXPY_Hip(y_impl->h_array, alpha, x_impl->h_array, length));
756     y_impl->d_array = NULL;
757   }
758   return CEED_ERROR_SUCCESS;
759 }
760 
761 //------------------------------------------------------------------------------
762 // Compute y = alpha x + beta y on the host
763 //------------------------------------------------------------------------------
764 static int CeedHostAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length) {
765   for (CeedSize i = 0; i < length; i++) y_array[i] = alpha * x_array[i] + beta * y_array[i];
766   return CEED_ERROR_SUCCESS;
767 }
768 
769 //------------------------------------------------------------------------------
770 // Compute y = alpha x + beta y on device (impl in .hip.cpp file)
771 //------------------------------------------------------------------------------
772 int CeedDeviceAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length);
773 
774 //------------------------------------------------------------------------------
775 // Compute y = alpha x + beta y
776 //------------------------------------------------------------------------------
777 static int CeedVectorAXPBY_Hip(CeedVector y, CeedScalar alpha, CeedScalar beta, CeedVector x) {
778   CeedSize        length;
779   CeedVector_Hip *y_impl, *x_impl;
780 
781   CeedCallBackend(CeedVectorGetData(y, &y_impl));
782   CeedCallBackend(CeedVectorGetData(x, &x_impl));
783   CeedCallBackend(CeedVectorGetLength(y, &length));
784   // Set value for synced device/host array
785   if (y_impl->d_array) {
786     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
787     CeedCallBackend(CeedDeviceAXPBY_Hip(y_impl->d_array, alpha, beta, x_impl->d_array, length));
788   }
789   if (y_impl->h_array) {
790     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
791     CeedCallBackend(CeedHostAXPBY_Hip(y_impl->h_array, alpha, beta, x_impl->h_array, length));
792   }
793   return CEED_ERROR_SUCCESS;
794 }
795 
796 //------------------------------------------------------------------------------
797 // Compute the pointwise multiplication w = x .* y on the host
798 //------------------------------------------------------------------------------
799 static int CeedHostPointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) {
800   for (CeedSize i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i];
801   return CEED_ERROR_SUCCESS;
802 }
803 
804 //------------------------------------------------------------------------------
805 // Compute the pointwise multiplication w = x .* y on device (impl in .hip.cpp file)
806 //------------------------------------------------------------------------------
807 int CeedDevicePointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length);
808 
809 //------------------------------------------------------------------------------
810 // Compute the pointwise multiplication w = x .* y
811 //------------------------------------------------------------------------------
812 static int CeedVectorPointwiseMult_Hip(CeedVector w, CeedVector x, CeedVector y) {
813   CeedSize        length;
814   CeedVector_Hip *w_impl, *x_impl, *y_impl;
815 
816   CeedCallBackend(CeedVectorGetData(w, &w_impl));
817   CeedCallBackend(CeedVectorGetData(x, &x_impl));
818   CeedCallBackend(CeedVectorGetData(y, &y_impl));
819   CeedCallBackend(CeedVectorGetLength(w, &length));
820 
821   // Set value for synced device/host array
822   if (!w_impl->d_array && !w_impl->h_array) {
823     CeedCallBackend(CeedVectorSetValue(w, 0.0));
824   }
825   if (w_impl->d_array) {
826     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
827     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE));
828     CeedCallBackend(CeedDevicePointwiseMult_Hip(w_impl->d_array, x_impl->d_array, y_impl->d_array, length));
829   }
830   if (w_impl->h_array) {
831     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
832     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST));
833     CeedCallBackend(CeedHostPointwiseMult_Hip(w_impl->h_array, x_impl->h_array, y_impl->h_array, length));
834   }
835   return CEED_ERROR_SUCCESS;
836 }
837 
838 //------------------------------------------------------------------------------
839 // Destroy the vector
840 //------------------------------------------------------------------------------
841 static int CeedVectorDestroy_Hip(const CeedVector vec) {
842   CeedVector_Hip *impl;
843 
844   CeedCallBackend(CeedVectorGetData(vec, &impl));
845   CeedCallHip(CeedVectorReturnCeed(vec), hipFree(impl->d_array_owned));
846   CeedCallBackend(CeedFree(&impl->h_array_owned));
847   CeedCallBackend(CeedFree(&impl));
848   return CEED_ERROR_SUCCESS;
849 }
850 
851 //------------------------------------------------------------------------------
852 // Create a vector of the specified length (does not allocate memory)
853 //------------------------------------------------------------------------------
854 int CeedVectorCreate_Hip(CeedSize n, CeedVector vec) {
855   CeedVector_Hip *impl;
856   Ceed            ceed;
857 
858   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
859   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Hip));
860   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Hip));
861   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Hip));
862   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Hip));
863   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "CopyStrided", CeedVectorCopyStrided_Hip));
864   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", CeedVectorSetValue_Hip));
865   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValueStrided", CeedVectorSetValueStrided_Hip));
866   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Hip));
867   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Hip));
868   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Hip));
869   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Hip));
870   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Norm", CeedVectorNorm_Hip));
871   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Hip));
872   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Scale", CeedVectorScale_Hip));
873   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPY", CeedVectorAXPY_Hip));
874   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPBY", CeedVectorAXPBY_Hip));
875   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Hip));
876   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Hip));
877   CeedCallBackend(CeedDestroy(&ceed));
878   CeedCallBackend(CeedCalloc(1, &impl));
879   CeedCallBackend(CeedVectorSetData(vec, impl));
880   return CEED_ERROR_SUCCESS;
881 }
882 
883 //------------------------------------------------------------------------------
884