xref: /libCEED/rust/libceed-sys/c-src/backends/hip-ref/ceed-hip-ref-vector.c (revision 832a6d734b42c29ce33664f3c9c828bd26de930d)
1 // Copyright (c) 2017-2025, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <math.h>
11 #include <stdbool.h>
12 #include <string.h>
13 #include <hip/hip_runtime.h>
14 
15 #include "../hip/ceed-hip-common.h"
16 #include "ceed-hip-ref.h"
17 
18 //------------------------------------------------------------------------------
19 // Check if host/device sync is needed
20 //------------------------------------------------------------------------------
21 static inline int CeedVectorNeedSync_Hip(const CeedVector vec, CeedMemType mem_type, bool *need_sync) {
22   CeedVector_Hip *impl;
23   bool            has_valid_array = false;
24 
25   CeedCallBackend(CeedVectorGetData(vec, &impl));
26   CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array));
27   switch (mem_type) {
28     case CEED_MEM_HOST:
29       *need_sync = has_valid_array && !impl->h_array;
30       break;
31     case CEED_MEM_DEVICE:
32       *need_sync = has_valid_array && !impl->d_array;
33       break;
34   }
35   return CEED_ERROR_SUCCESS;
36 }
37 
38 //------------------------------------------------------------------------------
39 // Sync host to device
40 //------------------------------------------------------------------------------
41 static inline int CeedVectorSyncH2D_Hip(const CeedVector vec) {
42   CeedSize        length;
43   size_t          bytes;
44   CeedVector_Hip *impl;
45 
46   CeedCallBackend(CeedVectorGetData(vec, &impl));
47 
48   CeedCheck(impl->h_array, CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND, "No valid host data to sync to device");
49 
50   CeedCallBackend(CeedVectorGetLength(vec, &length));
51   bytes = length * sizeof(CeedScalar);
52   if (impl->d_array_borrowed) {
53     impl->d_array = impl->d_array_borrowed;
54   } else if (impl->d_array_owned) {
55     impl->d_array = impl->d_array_owned;
56   } else {
57     CeedCallHip(CeedVectorReturnCeed(vec), hipMalloc((void **)&impl->d_array_owned, bytes));
58     impl->d_array = impl->d_array_owned;
59   }
60   CeedCallHip(CeedVectorReturnCeed(vec), hipMemcpy(impl->d_array, impl->h_array, bytes, hipMemcpyHostToDevice));
61   return CEED_ERROR_SUCCESS;
62 }
63 
64 //------------------------------------------------------------------------------
65 // Sync device to host
66 //------------------------------------------------------------------------------
67 static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) {
68   CeedSize        length;
69   size_t          bytes;
70   CeedVector_Hip *impl;
71 
72   CeedCallBackend(CeedVectorGetData(vec, &impl));
73 
74   CeedCheck(impl->d_array, CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND, "No valid device data to sync to host");
75 
76   if (impl->h_array_borrowed) {
77     impl->h_array = impl->h_array_borrowed;
78   } else if (impl->h_array_owned) {
79     impl->h_array = impl->h_array_owned;
80   } else {
81     CeedSize length;
82 
83     CeedCallBackend(CeedVectorGetLength(vec, &length));
84     CeedCallBackend(CeedCalloc(length, &impl->h_array_owned));
85     impl->h_array = impl->h_array_owned;
86   }
87 
88   CeedCallBackend(CeedVectorGetLength(vec, &length));
89   bytes = length * sizeof(CeedScalar);
90   CeedCallHip(CeedVectorReturnCeed(vec), hipMemcpy(impl->h_array, impl->d_array, bytes, hipMemcpyDeviceToHost));
91   return CEED_ERROR_SUCCESS;
92 }
93 
94 //------------------------------------------------------------------------------
95 // Sync arrays
96 //------------------------------------------------------------------------------
97 static int CeedVectorSyncArray_Hip(const CeedVector vec, CeedMemType mem_type) {
98   bool need_sync = false;
99 
100   // Check whether device/host sync is needed
101   CeedCallBackend(CeedVectorNeedSync_Hip(vec, mem_type, &need_sync));
102   if (!need_sync) return CEED_ERROR_SUCCESS;
103 
104   switch (mem_type) {
105     case CEED_MEM_HOST:
106       return CeedVectorSyncD2H_Hip(vec);
107     case CEED_MEM_DEVICE:
108       return CeedVectorSyncH2D_Hip(vec);
109   }
110   return CEED_ERROR_UNSUPPORTED;
111 }
112 
113 //------------------------------------------------------------------------------
114 // Set all pointers as invalid
115 //------------------------------------------------------------------------------
116 static inline int CeedVectorSetAllInvalid_Hip(const CeedVector vec) {
117   CeedVector_Hip *impl;
118 
119   CeedCallBackend(CeedVectorGetData(vec, &impl));
120   impl->h_array = NULL;
121   impl->d_array = NULL;
122   return CEED_ERROR_SUCCESS;
123 }
124 
125 //------------------------------------------------------------------------------
126 // Check if CeedVector has any valid pointer
127 //------------------------------------------------------------------------------
128 static inline int CeedVectorHasValidArray_Hip(const CeedVector vec, bool *has_valid_array) {
129   CeedVector_Hip *impl;
130 
131   CeedCallBackend(CeedVectorGetData(vec, &impl));
132   *has_valid_array = impl->h_array || impl->d_array;
133   return CEED_ERROR_SUCCESS;
134 }
135 
136 //------------------------------------------------------------------------------
137 // Check if has array of given type
138 //------------------------------------------------------------------------------
139 static inline int CeedVectorHasArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) {
140   CeedVector_Hip *impl;
141 
142   CeedCallBackend(CeedVectorGetData(vec, &impl));
143   switch (mem_type) {
144     case CEED_MEM_HOST:
145       *has_array_of_type = impl->h_array_borrowed || impl->h_array_owned;
146       break;
147     case CEED_MEM_DEVICE:
148       *has_array_of_type = impl->d_array_borrowed || impl->d_array_owned;
149       break;
150   }
151   return CEED_ERROR_SUCCESS;
152 }
153 
154 //------------------------------------------------------------------------------
155 // Check if has borrowed array of given type
156 //------------------------------------------------------------------------------
157 static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) {
158   CeedVector_Hip *impl;
159 
160   CeedCallBackend(CeedVectorGetData(vec, &impl));
161   switch (mem_type) {
162     case CEED_MEM_HOST:
163       *has_borrowed_array_of_type = impl->h_array_borrowed;
164       break;
165     case CEED_MEM_DEVICE:
166       *has_borrowed_array_of_type = impl->d_array_borrowed;
167       break;
168   }
169   return CEED_ERROR_SUCCESS;
170 }
171 
172 //------------------------------------------------------------------------------
173 // Set array from host
174 //------------------------------------------------------------------------------
175 static int CeedVectorSetArrayHost_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
176   CeedSize        length;
177   CeedVector_Hip *impl;
178 
179   CeedCallBackend(CeedVectorGetData(vec, &impl));
180   CeedCallBackend(CeedVectorGetLength(vec, &length));
181 
182   CeedCallBackend(CeedSetHostCeedScalarArray(array, copy_mode, length, (const CeedScalar **)&impl->h_array_owned,
183                                              (const CeedScalar **)&impl->h_array_borrowed, (const CeedScalar **)&impl->h_array));
184   return CEED_ERROR_SUCCESS;
185 }
186 
187 //------------------------------------------------------------------------------
188 // Set array from device
189 //------------------------------------------------------------------------------
190 static int CeedVectorSetArrayDevice_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
191   CeedSize        length;
192   Ceed            ceed;
193   CeedVector_Hip *impl;
194 
195   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
196   CeedCallBackend(CeedVectorGetData(vec, &impl));
197   CeedCallBackend(CeedVectorGetLength(vec, &length));
198 
199   CeedCallBackend(CeedSetDeviceCeedScalarArray_Hip(ceed, array, copy_mode, length, (const CeedScalar **)&impl->d_array_owned,
200                                                    (const CeedScalar **)&impl->d_array_borrowed, (const CeedScalar **)&impl->d_array));
201   CeedCallBackend(CeedDestroy(&ceed));
202   return CEED_ERROR_SUCCESS;
203 }
204 
205 //------------------------------------------------------------------------------
206 // Set the array used by a vector,
207 //   freeing any previously allocated array if applicable
208 //------------------------------------------------------------------------------
209 static int CeedVectorSetArray_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) {
210   CeedVector_Hip *impl;
211 
212   CeedCallBackend(CeedVectorGetData(vec, &impl));
213   CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
214   switch (mem_type) {
215     case CEED_MEM_HOST:
216       return CeedVectorSetArrayHost_Hip(vec, copy_mode, array);
217     case CEED_MEM_DEVICE:
218       return CeedVectorSetArrayDevice_Hip(vec, copy_mode, array);
219   }
220   return CEED_ERROR_UNSUPPORTED;
221 }
222 
223 //------------------------------------------------------------------------------
224 // Copy host array to value strided
225 //------------------------------------------------------------------------------
226 static int CeedHostCopyStrided_Hip(CeedScalar *h_array, CeedSize start, CeedSize stop, CeedSize step, CeedScalar *h_copy_array) {
227   for (CeedSize i = start; i < stop; i += step) h_copy_array[i] = h_array[i];
228   return CEED_ERROR_SUCCESS;
229 }
230 
231 //------------------------------------------------------------------------------
232 // Copy device array to value strided (impl in .hip.cpp file)
233 //------------------------------------------------------------------------------
234 int CeedDeviceCopyStrided_Hip(CeedScalar *d_array, CeedSize start, CeedSize stop, CeedSize step, CeedScalar *d_copy_array);
235 
236 //------------------------------------------------------------------------------
237 // Copy a vector to a value strided
238 //------------------------------------------------------------------------------
239 static int CeedVectorCopyStrided_Hip(CeedVector vec, CeedSize start, CeedSize stop, CeedSize step, CeedVector vec_copy) {
240   CeedSize        length;
241   CeedVector_Hip *impl;
242 
243   CeedCallBackend(CeedVectorGetData(vec, &impl));
244   {
245     CeedSize length_vec, length_copy;
246 
247     CeedCallBackend(CeedVectorGetLength(vec, &length_vec));
248     CeedCallBackend(CeedVectorGetLength(vec_copy, &length_copy));
249     length = length_vec < length_copy ? length_vec : length_copy;
250   }
251   if (stop == -1) stop = length;
252   // Set value for synced device/host array
253   if (impl->d_array) {
254     CeedScalar *copy_array;
255 
256     CeedCallBackend(CeedVectorGetArray(vec_copy, CEED_MEM_DEVICE, &copy_array));
257 #if (HIP_VERSION >= 60000000)
258     hipblasHandle_t handle;
259     Ceed            ceed;
260 
261     CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
262     CeedCallBackend(CeedGetHipblasHandle_Hip(ceed, &handle));
263 #if defined(CEED_SCALAR_IS_FP32)
264     CeedCallHipblas(ceed, hipblasScopy_64(handle, (int64_t)(stop - start), impl->d_array + start, (int64_t)step, copy_array + start, (int64_t)step));
265 #else  /* CEED_SCALAR */
266     CeedCallHipblas(ceed, hipblasDcopy_64(handle, (int64_t)(stop - start), impl->d_array + start, (int64_t)step, copy_array + start, (int64_t)step));
267 #endif /* CEED_SCALAR */
268 #else  /* HIP_VERSION */
269     CeedCallBackend(CeedDeviceCopyStrided_Hip(impl->d_array, start, stop, step, copy_array));
270 #endif /* HIP_VERSION */
271     CeedCallBackend(CeedVectorRestoreArray(vec_copy, &copy_array));
272     impl->h_array = NULL;
273     CeedCallBackend(CeedDestroy(&ceed));
274   } else if (impl->h_array) {
275     CeedScalar *copy_array;
276 
277     CeedCallBackend(CeedVectorGetArray(vec_copy, CEED_MEM_HOST, &copy_array));
278     CeedCallBackend(CeedHostCopyStrided_Hip(impl->h_array, start, stop, step, copy_array));
279     CeedCallBackend(CeedVectorRestoreArray(vec_copy, &copy_array));
280     impl->d_array = NULL;
281   } else {
282     return CeedError(CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND, "CeedVector must have valid data set");
283   }
284   return CEED_ERROR_SUCCESS;
285 }
286 
287 //------------------------------------------------------------------------------
288 // Set host array to value
289 //------------------------------------------------------------------------------
290 static int CeedHostSetValue_Hip(CeedScalar *h_array, CeedSize length, CeedScalar val) {
291   for (CeedSize i = 0; i < length; i++) h_array[i] = val;
292   return CEED_ERROR_SUCCESS;
293 }
294 
295 //------------------------------------------------------------------------------
296 // Set device array to value (impl in .hip file)
297 //------------------------------------------------------------------------------
298 int CeedDeviceSetValue_Hip(CeedScalar *d_array, CeedSize length, CeedScalar val);
299 
300 //------------------------------------------------------------------------------
301 // Set a vector to a value
302 //------------------------------------------------------------------------------
303 static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) {
304   CeedSize        length;
305   CeedVector_Hip *impl;
306 
307   CeedCallBackend(CeedVectorGetData(vec, &impl));
308   CeedCallBackend(CeedVectorGetLength(vec, &length));
309   // Set value for synced device/host array
310   if (!impl->d_array && !impl->h_array) {
311     if (impl->d_array_borrowed) {
312       impl->d_array = impl->d_array_borrowed;
313     } else if (impl->h_array_borrowed) {
314       impl->h_array = impl->h_array_borrowed;
315     } else if (impl->d_array_owned) {
316       impl->d_array = impl->d_array_owned;
317     } else if (impl->h_array_owned) {
318       impl->h_array = impl->h_array_owned;
319     } else {
320       CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL));
321     }
322   }
323   if (impl->d_array) {
324     if (val == 0) {
325       CeedCallHip(CeedVectorReturnCeed(vec), hipMemset(impl->d_array, 0, length * sizeof(CeedScalar)));
326     } else {
327       CeedCallBackend(CeedDeviceSetValue_Hip(impl->d_array, length, val));
328     }
329     impl->h_array = NULL;
330   } else if (impl->h_array) {
331     CeedCallBackend(CeedHostSetValue_Hip(impl->h_array, length, val));
332     impl->d_array = NULL;
333   }
334   return CEED_ERROR_SUCCESS;
335 }
336 
337 //------------------------------------------------------------------------------
338 // Set host array to value strided
339 //------------------------------------------------------------------------------
340 static int CeedHostSetValueStrided_Hip(CeedScalar *h_array, CeedSize start, CeedSize stop, CeedSize step, CeedScalar val) {
341   for (CeedSize i = start; i < stop; i += step) h_array[i] = val;
342   return CEED_ERROR_SUCCESS;
343 }
344 
345 //------------------------------------------------------------------------------
346 // Set device array to value strided (impl in .hip.cpp file)
347 //------------------------------------------------------------------------------
348 int CeedDeviceSetValueStrided_Hip(CeedScalar *d_array, CeedSize start, CeedSize stop, CeedSize step, CeedScalar val);
349 
350 //------------------------------------------------------------------------------
351 // Set a vector to a value strided
352 //------------------------------------------------------------------------------
353 static int CeedVectorSetValueStrided_Hip(CeedVector vec, CeedSize start, CeedSize stop, CeedSize step, CeedScalar val) {
354   CeedSize        length;
355   CeedVector_Hip *impl;
356 
357   CeedCallBackend(CeedVectorGetData(vec, &impl));
358   CeedCallBackend(CeedVectorGetLength(vec, &length));
359   // Set value for synced device/host array
360   if (stop == -1) stop = length;
361   if (impl->d_array) {
362     CeedCallBackend(CeedDeviceSetValueStrided_Hip(impl->d_array, start, stop, step, val));
363     impl->h_array = NULL;
364   } else if (impl->h_array) {
365     CeedCallBackend(CeedHostSetValueStrided_Hip(impl->h_array, start, stop, step, val));
366     impl->d_array = NULL;
367   } else {
368     return CeedError(CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND, "CeedVector must have valid data set");
369   }
370   return CEED_ERROR_SUCCESS;
371 }
372 
373 //------------------------------------------------------------------------------
374 // Vector Take Array
375 //------------------------------------------------------------------------------
376 static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
377   CeedVector_Hip *impl;
378 
379   CeedCallBackend(CeedVectorGetData(vec, &impl));
380 
381   // Sync array to requested mem_type
382   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
383 
384   // Update pointer
385   switch (mem_type) {
386     case CEED_MEM_HOST:
387       (*array)               = impl->h_array_borrowed;
388       impl->h_array_borrowed = NULL;
389       impl->h_array          = NULL;
390       break;
391     case CEED_MEM_DEVICE:
392       (*array)               = impl->d_array_borrowed;
393       impl->d_array_borrowed = NULL;
394       impl->d_array          = NULL;
395       break;
396   }
397   return CEED_ERROR_SUCCESS;
398 }
399 
400 //------------------------------------------------------------------------------
401 // Core logic for array syncronization for GetArray.
402 //   If a different memory type is most up to date, this will perform a copy
403 //------------------------------------------------------------------------------
404 static int CeedVectorGetArrayCore_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
405   CeedVector_Hip *impl;
406 
407   CeedCallBackend(CeedVectorGetData(vec, &impl));
408 
409   // Sync array to requested mem_type
410   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
411 
412   // Update pointer
413   switch (mem_type) {
414     case CEED_MEM_HOST:
415       *array = impl->h_array;
416       break;
417     case CEED_MEM_DEVICE:
418       *array = impl->d_array;
419       break;
420   }
421   return CEED_ERROR_SUCCESS;
422 }
423 
424 //------------------------------------------------------------------------------
425 // Get read-only access to a vector via the specified mem_type
426 //------------------------------------------------------------------------------
427 static int CeedVectorGetArrayRead_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) {
428   return CeedVectorGetArrayCore_Hip(vec, mem_type, (CeedScalar **)array);
429 }
430 
431 //------------------------------------------------------------------------------
432 // Get read/write access to a vector via the specified mem_type
433 //------------------------------------------------------------------------------
434 static int CeedVectorGetArray_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
435   CeedVector_Hip *impl;
436 
437   CeedCallBackend(CeedVectorGetData(vec, &impl));
438   CeedCallBackend(CeedVectorGetArrayCore_Hip(vec, mem_type, array));
439   CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
440   switch (mem_type) {
441     case CEED_MEM_HOST:
442       impl->h_array = *array;
443       break;
444     case CEED_MEM_DEVICE:
445       impl->d_array = *array;
446       break;
447   }
448   return CEED_ERROR_SUCCESS;
449 }
450 
451 //------------------------------------------------------------------------------
452 // Get write access to a vector via the specified mem_type
453 //------------------------------------------------------------------------------
454 static int CeedVectorGetArrayWrite_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
455   bool            has_array_of_type = true;
456   CeedVector_Hip *impl;
457 
458   CeedCallBackend(CeedVectorGetData(vec, &impl));
459   CeedCallBackend(CeedVectorHasArrayOfType_Hip(vec, mem_type, &has_array_of_type));
460   if (!has_array_of_type) {
461     // Allocate if array is not yet allocated
462     CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL));
463   } else {
464     // Select dirty array
465     switch (mem_type) {
466       case CEED_MEM_HOST:
467         if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed;
468         else impl->h_array = impl->h_array_owned;
469         break;
470       case CEED_MEM_DEVICE:
471         if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed;
472         else impl->d_array = impl->d_array_owned;
473     }
474   }
475   return CeedVectorGetArray_Hip(vec, mem_type, array);
476 }
477 
478 //------------------------------------------------------------------------------
479 // Get the norm of a CeedVector
480 //------------------------------------------------------------------------------
481 static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *norm) {
482   Ceed     ceed;
483   CeedSize length;
484 #if (HIP_VERSION < 60000000)
485   CeedSize num_calls;
486 #endif /* HIP_VERSION */
487   const CeedScalar *d_array;
488   CeedVector_Hip   *impl;
489   hipblasHandle_t   handle;
490 
491   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
492   CeedCallBackend(CeedVectorGetData(vec, &impl));
493   CeedCallBackend(CeedVectorGetLength(vec, &length));
494   CeedCallBackend(CeedGetHipblasHandle_Hip(ceed, &handle));
495 
496 #if (HIP_VERSION < 60000000)
497   // With ROCm 6, we can use the 64-bit integer interface. Prior to that,
498   // we need to check if the vector is too long to handle with int32,
499   // and if so, divide it into subsections for repeated hipBLAS calls.
500   num_calls = length / INT_MAX;
501   if (length % INT_MAX > 0) num_calls += 1;
502 #endif /* HIP_VERSION */
503 
504   // Compute norm
505   CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array));
506   switch (type) {
507     case CEED_NORM_1: {
508       *norm = 0.0;
509 #if defined(CEED_SCALAR_IS_FP32)
510 #if (HIP_VERSION >= 60000000)  // We have ROCm 6, and can use 64-bit integers
511       CeedCallHipblas(ceed, hipblasSasum_64(handle, (int64_t)length, (float *)d_array, 1, (float *)norm));
512 #else  /* HIP_VERSION */
513       float  sub_norm = 0.0;
514       float *d_array_start;
515 
516       for (CeedInt i = 0; i < num_calls; i++) {
517         d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
518         CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
519         CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
520 
521         CeedCallHipblas(ceed, cublasSasum(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
522         *norm += sub_norm;
523       }
524 #endif /* HIP_VERSION */
525 #else  /* CEED_SCALAR */
526 #if (HIP_VERSION >= 60000000)
527       CeedCallHipblas(ceed, hipblasDasum_64(handle, (int64_t)length, (double *)d_array, 1, (double *)norm));
528 #else  /* HIP_VERSION */
529       double  sub_norm = 0.0;
530       double *d_array_start;
531 
532       for (CeedInt i = 0; i < num_calls; i++) {
533         d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
534         CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
535         CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
536 
537         CeedCallHipblas(ceed, hipblasDasum(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
538         *norm += sub_norm;
539       }
540 #endif /* HIP_VERSION */
541 #endif /* CEED_SCALAR */
542       break;
543     }
544     case CEED_NORM_2: {
545 #if defined(CEED_SCALAR_IS_FP32)
546 #if (HIP_VERSION >= 60000000)
547       CeedCallHipblas(ceed, hipblasSnrm2_64(handle, (int64_t)length, (float *)d_array, 1, (float *)norm));
548 #else  /* CUDA_VERSION */
549       float  sub_norm = 0.0, norm_sum = 0.0;
550       float *d_array_start;
551 
552       for (CeedInt i = 0; i < num_calls; i++) {
553         d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
554         CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
555         CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
556 
557         CeedCallHipblas(ceed, hipblasSnrm2(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
558         norm_sum += sub_norm * sub_norm;
559       }
560       *norm = sqrt(norm_sum);
561 #endif /* HIP_VERSION */
562 #else  /* CEED_SCALAR */
563 #if (HIP_VERSION >= 60000000)
564       CeedCallHipblas(ceed, hipblasDnrm2_64(handle, (int64_t)length, (double *)d_array, 1, (double *)norm));
565 #else  /* CUDA_VERSION */
566       double  sub_norm = 0.0, norm_sum = 0.0;
567       double *d_array_start;
568 
569       for (CeedInt i = 0; i < num_calls; i++) {
570         d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
571         CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
572         CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
573 
574         CeedCallHipblas(ceed, hipblasDnrm2(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
575         norm_sum += sub_norm * sub_norm;
576       }
577       *norm = sqrt(norm_sum);
578 #endif /* HIP_VERSION */
579 #endif /* CEED_SCALAR */
580       break;
581     }
582     case CEED_NORM_MAX: {
583 #if defined(CEED_SCALAR_IS_FP32)
584 #if (HIP_VERSION >= 60000000)
585       int64_t    index;
586       CeedScalar norm_no_abs;
587 
588       CeedCallHipblas(ceed, hipblasIsamax_64(handle, (int64_t)length, (float *)d_array, 1, &index));
589       CeedCallHip(ceed, hipMemcpy(&norm_no_abs, impl->d_array + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
590       *norm = fabs(norm_no_abs);
591 #else  /* HIP_VERSION */
592       CeedInt index;
593       float   sub_max = 0.0, current_max = 0.0;
594       float  *d_array_start;
595 
596       for (CeedInt i = 0; i < num_calls; i++) {
597         d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
598         CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
599         CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
600 
601         CeedCallHipblas(ceed, hipblasIsamax(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &index));
602         CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
603         if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
604       }
605       *norm = current_max;
606 #endif /* HIP_VERSION */
607 #else  /* CEED_SCALAR */
608 #if (HIP_VERSION >= 60000000)
609       int64_t    index;
610       CeedScalar norm_no_abs;
611 
612       CeedCallHipblas(ceed, hipblasIdamax_64(handle, (int64_t)length, (double *)d_array, 1, &index));
613       CeedCallHip(ceed, hipMemcpy(&norm_no_abs, impl->d_array + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
614       *norm = fabs(norm_no_abs);
615 #else  /* HIP_VERSION */
616       CeedInt index;
617       double  sub_max = 0.0, current_max = 0.0;
618       double *d_array_start;
619 
620       for (CeedInt i = 0; i < num_calls; i++) {
621         d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
622         CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
623         CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
624 
625         CeedCallHipblas(ceed, hipblasIdamax(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &index));
626         CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
627         if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
628       }
629       *norm = current_max;
630 #endif /* HIP_VERSION */
631 #endif /* CEED_SCALAR */
632       break;
633     }
634   }
635   CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array));
636   CeedCallBackend(CeedDestroy(&ceed));
637   return CEED_ERROR_SUCCESS;
638 }
639 
640 //------------------------------------------------------------------------------
641 // Take reciprocal of a vector on host
642 //------------------------------------------------------------------------------
643 static int CeedHostReciprocal_Hip(CeedScalar *h_array, CeedSize length) {
644   for (CeedSize i = 0; i < length; i++) {
645     if (fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i];
646   }
647   return CEED_ERROR_SUCCESS;
648 }
649 
650 //------------------------------------------------------------------------------
651 // Take reciprocal of a vector on device (impl in .hip.cpp file)
652 //------------------------------------------------------------------------------
653 int CeedDeviceReciprocal_Hip(CeedScalar *d_array, CeedSize length);
654 
655 //------------------------------------------------------------------------------
656 // Take reciprocal of a vector
657 //------------------------------------------------------------------------------
658 static int CeedVectorReciprocal_Hip(CeedVector vec) {
659   CeedSize        length;
660   CeedVector_Hip *impl;
661 
662   CeedCallBackend(CeedVectorGetData(vec, &impl));
663   CeedCallBackend(CeedVectorGetLength(vec, &length));
664   // Set value for synced device/host array
665   if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Hip(impl->d_array, length));
666   if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Hip(impl->h_array, length));
667   return CEED_ERROR_SUCCESS;
668 }
669 
670 //------------------------------------------------------------------------------
671 // Compute x = alpha x on the host
672 //------------------------------------------------------------------------------
673 static int CeedHostScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length) {
674   for (CeedSize i = 0; i < length; i++) x_array[i] *= alpha;
675   return CEED_ERROR_SUCCESS;
676 }
677 
678 //------------------------------------------------------------------------------
679 // Compute x = alpha x on device (impl in .hip.cpp file)
680 //------------------------------------------------------------------------------
681 int CeedDeviceScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length);
682 
683 //------------------------------------------------------------------------------
684 // Compute x = alpha x
685 //------------------------------------------------------------------------------
686 static int CeedVectorScale_Hip(CeedVector x, CeedScalar alpha) {
687   CeedSize        length;
688   CeedVector_Hip *impl;
689 
690   CeedCallBackend(CeedVectorGetData(x, &impl));
691   CeedCallBackend(CeedVectorGetLength(x, &length));
692   // Set value for synced device/host array
693   if (impl->d_array) {
694 #if (HIP_VERSION >= 60000000)
695     hipblasHandle_t handle;
696 
697     CeedCallBackend(CeedGetHipblasHandle_Hip(CeedVectorReturnCeed(x), &handle));
698 #if defined(CEED_SCALAR_IS_FP32)
699     CeedCallHipblas(CeedVectorReturnCeed(x), hipblasSscal_64(handle, (int64_t)length, &alpha, impl->d_array, 1));
700 #else  /* CEED_SCALAR */
701     CeedCallHipblas(CeedVectorReturnCeed(x), hipblasDscal_64(handle, (int64_t)length, &alpha, impl->d_array, 1));
702 #endif /* CEED_SCALAR */
703 #else  /* HIP_VERSION */
704     CeedCallBackend(CeedDeviceScale_Hip(impl->d_array, alpha, length));
705 #endif /* HIP_VERSION */
706     impl->h_array = NULL;
707   }
708   if (impl->h_array) {
709     CeedCallBackend(CeedHostScale_Hip(impl->h_array, alpha, length));
710     impl->d_array = NULL;
711   }
712   return CEED_ERROR_SUCCESS;
713 }
714 
715 //------------------------------------------------------------------------------
716 // Compute y = alpha x + y on the host
717 //------------------------------------------------------------------------------
718 static int CeedHostAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) {
719   for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i];
720   return CEED_ERROR_SUCCESS;
721 }
722 
723 //------------------------------------------------------------------------------
724 // Compute y = alpha x + y on device (impl in .hip.cpp file)
725 //------------------------------------------------------------------------------
726 int CeedDeviceAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length);
727 
728 //------------------------------------------------------------------------------
729 // Compute y = alpha x + y
730 //------------------------------------------------------------------------------
731 static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) {
732   CeedSize        length;
733   CeedVector_Hip *y_impl, *x_impl;
734 
735   CeedCallBackend(CeedVectorGetData(y, &y_impl));
736   CeedCallBackend(CeedVectorGetData(x, &x_impl));
737   CeedCallBackend(CeedVectorGetLength(y, &length));
738   // Set value for synced device/host array
739   if (y_impl->d_array) {
740     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
741 #if (HIP_VERSION >= 60000000)
742     hipblasHandle_t handle;
743 
744     CeedCallBackend(CeedGetHipblasHandle_Hip(CeedVectorReturnCeed(y), &handle));
745 #if defined(CEED_SCALAR_IS_FP32)
746     CeedCallHipblas(CeedVectorReturnCeed(y), hipblasSaxpy_64(handle, (int64_t)length, &alpha, x_impl->d_array, 1, y_impl->d_array, 1));
747 #else  /* CEED_SCALAR */
748     CeedCallHipblas(CeedVectorReturnCeed(y), hipblasDaxpy_64(handle, (int64_t)length, &alpha, x_impl->d_array, 1, y_impl->d_array, 1));
749 #endif /* CEED_SCALAR */
750 #else  /* HIP_VERSION */
751     CeedCallBackend(CeedDeviceAXPY_Hip(y_impl->d_array, alpha, x_impl->d_array, length));
752 #endif /* HIP_VERSION */
753     y_impl->h_array = NULL;
754   } else if (y_impl->h_array) {
755     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
756     CeedCallBackend(CeedHostAXPY_Hip(y_impl->h_array, alpha, x_impl->h_array, length));
757     y_impl->d_array = NULL;
758   }
759   return CEED_ERROR_SUCCESS;
760 }
761 
762 //------------------------------------------------------------------------------
763 // Compute y = alpha x + beta y on the host
764 //------------------------------------------------------------------------------
765 static int CeedHostAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length) {
766   for (CeedSize i = 0; i < length; i++) y_array[i] = alpha * x_array[i] + beta * y_array[i];
767   return CEED_ERROR_SUCCESS;
768 }
769 
770 //------------------------------------------------------------------------------
771 // Compute y = alpha x + beta y on device (impl in .hip.cpp file)
772 //------------------------------------------------------------------------------
773 int CeedDeviceAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length);
774 
775 //------------------------------------------------------------------------------
776 // Compute y = alpha x + beta y
777 //------------------------------------------------------------------------------
778 static int CeedVectorAXPBY_Hip(CeedVector y, CeedScalar alpha, CeedScalar beta, CeedVector x) {
779   CeedSize        length;
780   CeedVector_Hip *y_impl, *x_impl;
781 
782   CeedCallBackend(CeedVectorGetData(y, &y_impl));
783   CeedCallBackend(CeedVectorGetData(x, &x_impl));
784   CeedCallBackend(CeedVectorGetLength(y, &length));
785   // Set value for synced device/host array
786   if (y_impl->d_array) {
787     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
788     CeedCallBackend(CeedDeviceAXPBY_Hip(y_impl->d_array, alpha, beta, x_impl->d_array, length));
789   }
790   if (y_impl->h_array) {
791     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
792     CeedCallBackend(CeedHostAXPBY_Hip(y_impl->h_array, alpha, beta, x_impl->h_array, length));
793   }
794   return CEED_ERROR_SUCCESS;
795 }
796 
797 //------------------------------------------------------------------------------
798 // Compute the pointwise multiplication w = x .* y on the host
799 //------------------------------------------------------------------------------
800 static int CeedHostPointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) {
801   for (CeedSize i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i];
802   return CEED_ERROR_SUCCESS;
803 }
804 
805 //------------------------------------------------------------------------------
806 // Compute the pointwise multiplication w = x .* y on device (impl in .hip.cpp file)
807 //------------------------------------------------------------------------------
808 int CeedDevicePointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length);
809 
810 //------------------------------------------------------------------------------
811 // Compute the pointwise multiplication w = x .* y
812 //------------------------------------------------------------------------------
813 static int CeedVectorPointwiseMult_Hip(CeedVector w, CeedVector x, CeedVector y) {
814   CeedSize        length;
815   CeedVector_Hip *w_impl, *x_impl, *y_impl;
816 
817   CeedCallBackend(CeedVectorGetData(w, &w_impl));
818   CeedCallBackend(CeedVectorGetData(x, &x_impl));
819   CeedCallBackend(CeedVectorGetData(y, &y_impl));
820   CeedCallBackend(CeedVectorGetLength(w, &length));
821 
822   // Set value for synced device/host array
823   if (!w_impl->d_array && !w_impl->h_array) {
824     CeedCallBackend(CeedVectorSetValue(w, 0.0));
825   }
826   if (w_impl->d_array) {
827     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
828     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE));
829     CeedCallBackend(CeedDevicePointwiseMult_Hip(w_impl->d_array, x_impl->d_array, y_impl->d_array, length));
830   }
831   if (w_impl->h_array) {
832     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
833     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST));
834     CeedCallBackend(CeedHostPointwiseMult_Hip(w_impl->h_array, x_impl->h_array, y_impl->h_array, length));
835   }
836   return CEED_ERROR_SUCCESS;
837 }
838 
839 //------------------------------------------------------------------------------
840 // Destroy the vector
841 //------------------------------------------------------------------------------
842 static int CeedVectorDestroy_Hip(const CeedVector vec) {
843   CeedVector_Hip *impl;
844 
845   CeedCallBackend(CeedVectorGetData(vec, &impl));
846   CeedCallHip(CeedVectorReturnCeed(vec), hipFree(impl->d_array_owned));
847   CeedCallBackend(CeedFree(&impl->h_array_owned));
848   CeedCallBackend(CeedFree(&impl));
849   return CEED_ERROR_SUCCESS;
850 }
851 
852 //------------------------------------------------------------------------------
853 // Create a vector of the specified length (does not allocate memory)
854 //------------------------------------------------------------------------------
855 int CeedVectorCreate_Hip(CeedSize n, CeedVector vec) {
856   CeedVector_Hip *impl;
857   Ceed            ceed;
858 
859   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
860   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Hip));
861   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Hip));
862   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Hip));
863   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Hip));
864   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "CopyStrided", CeedVectorCopyStrided_Hip));
865   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", CeedVectorSetValue_Hip));
866   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValueStrided", CeedVectorSetValueStrided_Hip));
867   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Hip));
868   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Hip));
869   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Hip));
870   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Hip));
871   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Norm", CeedVectorNorm_Hip));
872   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Hip));
873   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Scale", CeedVectorScale_Hip));
874   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPY", CeedVectorAXPY_Hip));
875   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPBY", CeedVectorAXPBY_Hip));
876   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Hip));
877   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Hip));
878   CeedCallBackend(CeedDestroy(&ceed));
879   CeedCallBackend(CeedCalloc(1, &impl));
880   CeedCallBackend(CeedVectorSetData(vec, impl));
881   return CEED_ERROR_SUCCESS;
882 }
883 
884 //------------------------------------------------------------------------------
885