xref: /libCEED/backends/hip-ref/ceed-hip-ref-vector.c (revision 19a04db8f2245ce5b2896f00805cba58aafa343d)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <math.h>
11 #include <stdbool.h>
12 #include <string.h>
13 #include <hip/hip_runtime.h>
14 
15 #include "../hip/ceed-hip-common.h"
16 #include "ceed-hip-ref.h"
17 
18 //------------------------------------------------------------------------------
19 // Check if host/device sync is needed
20 //------------------------------------------------------------------------------
21 static inline int CeedVectorNeedSync_Hip(const CeedVector vec, CeedMemType mem_type, bool *need_sync) {
22   CeedVector_Hip *impl;
23   bool            has_valid_array = false;
24 
25   CeedCallBackend(CeedVectorGetData(vec, &impl));
26   CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array));
27   switch (mem_type) {
28     case CEED_MEM_HOST:
29       *need_sync = has_valid_array && !impl->h_array;
30       break;
31     case CEED_MEM_DEVICE:
32       *need_sync = has_valid_array && !impl->d_array;
33       break;
34   }
35   return CEED_ERROR_SUCCESS;
36 }
37 
38 //------------------------------------------------------------------------------
39 // Sync host to device
40 //------------------------------------------------------------------------------
41 static inline int CeedVectorSyncH2D_Hip(const CeedVector vec) {
42   Ceed            ceed;
43   CeedSize        length;
44   size_t          bytes;
45   CeedVector_Hip *impl;
46 
47   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
48   CeedCallBackend(CeedVectorGetData(vec, &impl));
49 
50   CeedCheck(impl->h_array, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
51 
52   CeedCallBackend(CeedVectorGetLength(vec, &length));
53   bytes = length * sizeof(CeedScalar);
54   if (impl->d_array_borrowed) {
55     impl->d_array = impl->d_array_borrowed;
56   } else if (impl->d_array_owned) {
57     impl->d_array = impl->d_array_owned;
58   } else {
59     CeedCallHip(ceed, hipMalloc((void **)&impl->d_array_owned, bytes));
60     impl->d_array = impl->d_array_owned;
61   }
62   CeedCallHip(ceed, hipMemcpy(impl->d_array, impl->h_array, bytes, hipMemcpyHostToDevice));
63   return CEED_ERROR_SUCCESS;
64 }
65 
66 //------------------------------------------------------------------------------
67 // Sync device to host
68 //------------------------------------------------------------------------------
69 static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) {
70   Ceed            ceed;
71   CeedSize        length;
72   size_t          bytes;
73   CeedVector_Hip *impl;
74 
75   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
76   CeedCallBackend(CeedVectorGetData(vec, &impl));
77 
78   CeedCheck(impl->d_array, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
79 
80   if (impl->h_array_borrowed) {
81     impl->h_array = impl->h_array_borrowed;
82   } else if (impl->h_array_owned) {
83     impl->h_array = impl->h_array_owned;
84   } else {
85     CeedSize length;
86 
87     CeedCallBackend(CeedVectorGetLength(vec, &length));
88     CeedCallBackend(CeedCalloc(length, &impl->h_array_owned));
89     impl->h_array = impl->h_array_owned;
90   }
91 
92   CeedCallBackend(CeedVectorGetLength(vec, &length));
93   bytes = length * sizeof(CeedScalar);
94   CeedCallHip(ceed, hipMemcpy(impl->h_array, impl->d_array, bytes, hipMemcpyDeviceToHost));
95   return CEED_ERROR_SUCCESS;
96 }
97 
98 //------------------------------------------------------------------------------
99 // Sync arrays
100 //------------------------------------------------------------------------------
101 static int CeedVectorSyncArray_Hip(const CeedVector vec, CeedMemType mem_type) {
102   bool need_sync = false;
103 
104   // Check whether device/host sync is needed
105   CeedCallBackend(CeedVectorNeedSync_Hip(vec, mem_type, &need_sync));
106   if (!need_sync) return CEED_ERROR_SUCCESS;
107 
108   switch (mem_type) {
109     case CEED_MEM_HOST:
110       return CeedVectorSyncD2H_Hip(vec);
111     case CEED_MEM_DEVICE:
112       return CeedVectorSyncH2D_Hip(vec);
113   }
114   return CEED_ERROR_UNSUPPORTED;
115 }
116 
117 //------------------------------------------------------------------------------
118 // Set all pointers as invalid
119 //------------------------------------------------------------------------------
120 static inline int CeedVectorSetAllInvalid_Hip(const CeedVector vec) {
121   CeedVector_Hip *impl;
122 
123   CeedCallBackend(CeedVectorGetData(vec, &impl));
124   impl->h_array = NULL;
125   impl->d_array = NULL;
126   return CEED_ERROR_SUCCESS;
127 }
128 
129 //------------------------------------------------------------------------------
130 // Check if CeedVector has any valid pointer
131 //------------------------------------------------------------------------------
132 static inline int CeedVectorHasValidArray_Hip(const CeedVector vec, bool *has_valid_array) {
133   CeedVector_Hip *impl;
134 
135   CeedCallBackend(CeedVectorGetData(vec, &impl));
136   *has_valid_array = impl->h_array || impl->d_array;
137   return CEED_ERROR_SUCCESS;
138 }
139 
140 //------------------------------------------------------------------------------
141 // Check if has array of given type
142 //------------------------------------------------------------------------------
143 static inline int CeedVectorHasArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) {
144   CeedVector_Hip *impl;
145 
146   CeedCallBackend(CeedVectorGetData(vec, &impl));
147   switch (mem_type) {
148     case CEED_MEM_HOST:
149       *has_array_of_type = impl->h_array_borrowed || impl->h_array_owned;
150       break;
151     case CEED_MEM_DEVICE:
152       *has_array_of_type = impl->d_array_borrowed || impl->d_array_owned;
153       break;
154   }
155   return CEED_ERROR_SUCCESS;
156 }
157 
158 //------------------------------------------------------------------------------
159 // Check if has borrowed array of given type
160 //------------------------------------------------------------------------------
161 static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) {
162   CeedVector_Hip *impl;
163 
164   CeedCallBackend(CeedVectorGetData(vec, &impl));
165   switch (mem_type) {
166     case CEED_MEM_HOST:
167       *has_borrowed_array_of_type = impl->h_array_borrowed;
168       break;
169     case CEED_MEM_DEVICE:
170       *has_borrowed_array_of_type = impl->d_array_borrowed;
171       break;
172   }
173   return CEED_ERROR_SUCCESS;
174 }
175 
176 //------------------------------------------------------------------------------
177 // Set array from host
178 //------------------------------------------------------------------------------
179 static int CeedVectorSetArrayHost_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
180   CeedSize        length;
181   CeedVector_Hip *impl;
182 
183   CeedCallBackend(CeedVectorGetData(vec, &impl));
184   CeedCallBackend(CeedVectorGetLength(vec, &length));
185 
186   CeedCallBackend(CeedSetHostCeedScalarArray(array, copy_mode, length, (const CeedScalar **)&impl->h_array_owned,
187                                              (const CeedScalar **)&impl->h_array_borrowed, (const CeedScalar **)&impl->h_array));
188   return CEED_ERROR_SUCCESS;
189 }
190 
191 //------------------------------------------------------------------------------
192 // Set array from device
193 //------------------------------------------------------------------------------
194 static int CeedVectorSetArrayDevice_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
195   CeedSize        length;
196   Ceed            ceed;
197   CeedVector_Hip *impl;
198 
199   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
200   CeedCallBackend(CeedVectorGetData(vec, &impl));
201   CeedCallBackend(CeedVectorGetLength(vec, &length));
202 
203   CeedCallBackend(CeedSetDeviceCeedScalarArray_Hip(ceed, array, copy_mode, length, (const CeedScalar **)&impl->d_array_owned,
204                                                    (const CeedScalar **)&impl->d_array_borrowed, (const CeedScalar **)&impl->d_array));
205   return CEED_ERROR_SUCCESS;
206 }
207 
208 //------------------------------------------------------------------------------
209 // Set the array used by a vector,
210 //   freeing any previously allocated array if applicable
211 //------------------------------------------------------------------------------
212 static int CeedVectorSetArray_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) {
213   CeedVector_Hip *impl;
214 
215   CeedCallBackend(CeedVectorGetData(vec, &impl));
216   CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
217   switch (mem_type) {
218     case CEED_MEM_HOST:
219       return CeedVectorSetArrayHost_Hip(vec, copy_mode, array);
220     case CEED_MEM_DEVICE:
221       return CeedVectorSetArrayDevice_Hip(vec, copy_mode, array);
222   }
223   return CEED_ERROR_UNSUPPORTED;
224 }
225 
226 //------------------------------------------------------------------------------
227 // Copy host array to value strided
228 //------------------------------------------------------------------------------
229 static int CeedHostCopyStrided_Hip(CeedScalar *h_array, CeedSize start, CeedSize step, CeedSize length, CeedScalar *h_copy_array) {
230   for (CeedSize i = start; i < length; i += step) h_copy_array[i] = h_array[i];
231   return CEED_ERROR_SUCCESS;
232 }
233 
234 //------------------------------------------------------------------------------
235 // Copy device array to value strided (impl in .hip.cpp file)
236 //------------------------------------------------------------------------------
237 int CeedDeviceCopyStrided_Hip(CeedScalar *d_array, CeedSize start, CeedSize step, CeedSize length, CeedScalar *d_copy_array);
238 
239 //------------------------------------------------------------------------------
240 // Copy a vector to a value strided
241 //------------------------------------------------------------------------------
242 static int CeedVectorCopyStrided_Hip(CeedVector vec, CeedSize start, CeedSize step, CeedVector vec_copy) {
243   CeedSize        length;
244   CeedVector_Hip *impl;
245 
246   CeedCallBackend(CeedVectorGetData(vec, &impl));
247   {
248     CeedSize length_vec, length_copy;
249 
250     CeedCallBackend(CeedVectorGetLength(vec, &length_vec));
251     CeedCallBackend(CeedVectorGetLength(vec_copy, &length_copy));
252     length = length_vec < length_copy ? length_vec : length_copy;
253   }
254   // Set value for synced device/host array
255   if (impl->d_array) {
256     CeedScalar *copy_array;
257 
258     CeedCallBackend(CeedVectorGetArray(vec_copy, CEED_MEM_DEVICE, &copy_array));
259     CeedCallBackend(CeedDeviceCopyStrided_Hip(impl->d_array, start, step, length, copy_array));
260     CeedCallBackend(CeedVectorRestoreArray(vec_copy, &copy_array));
261   } else if (impl->h_array) {
262     CeedScalar *copy_array;
263 
264     CeedCallBackend(CeedVectorGetArray(vec_copy, CEED_MEM_HOST, &copy_array));
265     CeedCallBackend(CeedHostCopyStrided_Hip(impl->h_array, start, step, length, copy_array));
266     CeedCallBackend(CeedVectorRestoreArray(vec_copy, &copy_array));
267   } else {
268     return CeedError(CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND, "CeedVector must have valid data set");
269   }
270   return CEED_ERROR_SUCCESS;
271 }
272 
273 //------------------------------------------------------------------------------
274 // Set host array to value
275 //------------------------------------------------------------------------------
276 static int CeedHostSetValue_Hip(CeedScalar *h_array, CeedSize length, CeedScalar val) {
277   for (CeedSize i = 0; i < length; i++) h_array[i] = val;
278   return CEED_ERROR_SUCCESS;
279 }
280 
281 //------------------------------------------------------------------------------
282 // Set device array to value (impl in .hip.cpp file)
283 //------------------------------------------------------------------------------
284 int CeedDeviceSetValue_Hip(CeedScalar *d_array, CeedSize length, CeedScalar val);
285 
286 //------------------------------------------------------------------------------
287 // Set a vector to a value
288 //------------------------------------------------------------------------------
289 static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) {
290   CeedSize        length;
291   CeedVector_Hip *impl;
292 
293   CeedCallBackend(CeedVectorGetData(vec, &impl));
294   CeedCallBackend(CeedVectorGetLength(vec, &length));
295   // Set value for synced device/host array
296   if (!impl->d_array && !impl->h_array) {
297     if (impl->d_array_borrowed) {
298       impl->d_array = impl->d_array_borrowed;
299     } else if (impl->h_array_borrowed) {
300       impl->h_array = impl->h_array_borrowed;
301     } else if (impl->d_array_owned) {
302       impl->d_array = impl->d_array_owned;
303     } else if (impl->h_array_owned) {
304       impl->h_array = impl->h_array_owned;
305     } else {
306       CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL));
307     }
308   }
309   if (impl->d_array) {
310     CeedCallBackend(CeedDeviceSetValue_Hip(impl->d_array, length, val));
311     impl->h_array = NULL;
312   }
313   if (impl->h_array) {
314     CeedCallBackend(CeedHostSetValue_Hip(impl->h_array, length, val));
315     impl->d_array = NULL;
316   }
317   return CEED_ERROR_SUCCESS;
318 }
319 
320 //------------------------------------------------------------------------------
321 // Set host array to value strided
322 //------------------------------------------------------------------------------
323 static int CeedHostSetValueStrided_Hip(CeedScalar *h_array, CeedSize start, CeedSize step, CeedSize length, CeedScalar val) {
324   for (CeedSize i = start; i < length; i += step) h_array[i] = val;
325   return CEED_ERROR_SUCCESS;
326 }
327 
328 //------------------------------------------------------------------------------
329 // Set device array to value strided (impl in .hip.cpp file)
330 //------------------------------------------------------------------------------
331 int CeedDeviceSetValueStrided_Hip(CeedScalar *d_array, CeedSize start, CeedSize step, CeedSize length, CeedScalar val);
332 
333 //------------------------------------------------------------------------------
334 // Set a vector to a value strided
335 //------------------------------------------------------------------------------
336 static int CeedVectorSetValueStrided_Hip(CeedVector vec, CeedSize start, CeedSize step, CeedScalar val) {
337   CeedSize        length;
338   CeedVector_Hip *impl;
339 
340   CeedCallBackend(CeedVectorGetData(vec, &impl));
341   CeedCallBackend(CeedVectorGetLength(vec, &length));
342   // Set value for synced device/host array
343   if (impl->d_array) {
344     CeedCallBackend(CeedDeviceSetValueStrided_Hip(impl->d_array, start, step, length, val));
345     impl->h_array = NULL;
346   } else if (impl->h_array) {
347     CeedCallBackend(CeedHostSetValueStrided_Hip(impl->h_array, start, step, length, val));
348     impl->d_array = NULL;
349   } else {
350     return CeedError(CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND, "CeedVector must have valid data set");
351   }
352   return CEED_ERROR_SUCCESS;
353 }
354 
355 //------------------------------------------------------------------------------
356 // Vector Take Array
357 //------------------------------------------------------------------------------
358 static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
359   CeedVector_Hip *impl;
360 
361   CeedCallBackend(CeedVectorGetData(vec, &impl));
362 
363   // Sync array to requested mem_type
364   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
365 
366   // Update pointer
367   switch (mem_type) {
368     case CEED_MEM_HOST:
369       (*array)               = impl->h_array_borrowed;
370       impl->h_array_borrowed = NULL;
371       impl->h_array          = NULL;
372       break;
373     case CEED_MEM_DEVICE:
374       (*array)               = impl->d_array_borrowed;
375       impl->d_array_borrowed = NULL;
376       impl->d_array          = NULL;
377       break;
378   }
379   return CEED_ERROR_SUCCESS;
380 }
381 
382 //------------------------------------------------------------------------------
383 // Core logic for array syncronization for GetArray.
384 //   If a different memory type is most up to date, this will perform a copy
385 //------------------------------------------------------------------------------
386 static int CeedVectorGetArrayCore_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
387   CeedVector_Hip *impl;
388 
389   CeedCallBackend(CeedVectorGetData(vec, &impl));
390 
391   // Sync array to requested mem_type
392   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
393 
394   // Update pointer
395   switch (mem_type) {
396     case CEED_MEM_HOST:
397       *array = impl->h_array;
398       break;
399     case CEED_MEM_DEVICE:
400       *array = impl->d_array;
401       break;
402   }
403   return CEED_ERROR_SUCCESS;
404 }
405 
406 //------------------------------------------------------------------------------
407 // Get read-only access to a vector via the specified mem_type
408 //------------------------------------------------------------------------------
409 static int CeedVectorGetArrayRead_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) {
410   return CeedVectorGetArrayCore_Hip(vec, mem_type, (CeedScalar **)array);
411 }
412 
413 //------------------------------------------------------------------------------
414 // Get read/write access to a vector via the specified mem_type
415 //------------------------------------------------------------------------------
416 static int CeedVectorGetArray_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
417   CeedVector_Hip *impl;
418 
419   CeedCallBackend(CeedVectorGetData(vec, &impl));
420   CeedCallBackend(CeedVectorGetArrayCore_Hip(vec, mem_type, array));
421   CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
422   switch (mem_type) {
423     case CEED_MEM_HOST:
424       impl->h_array = *array;
425       break;
426     case CEED_MEM_DEVICE:
427       impl->d_array = *array;
428       break;
429   }
430   return CEED_ERROR_SUCCESS;
431 }
432 
433 //------------------------------------------------------------------------------
434 // Get write access to a vector via the specified mem_type
435 //------------------------------------------------------------------------------
436 static int CeedVectorGetArrayWrite_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
437   bool            has_array_of_type = true;
438   CeedVector_Hip *impl;
439 
440   CeedCallBackend(CeedVectorGetData(vec, &impl));
441   CeedCallBackend(CeedVectorHasArrayOfType_Hip(vec, mem_type, &has_array_of_type));
442   if (!has_array_of_type) {
443     // Allocate if array is not yet allocated
444     CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL));
445   } else {
446     // Select dirty array
447     switch (mem_type) {
448       case CEED_MEM_HOST:
449         if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed;
450         else impl->h_array = impl->h_array_owned;
451         break;
452       case CEED_MEM_DEVICE:
453         if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed;
454         else impl->d_array = impl->d_array_owned;
455     }
456   }
457   return CeedVectorGetArray_Hip(vec, mem_type, array);
458 }
459 
460 //------------------------------------------------------------------------------
461 // Get the norm of a CeedVector
462 //------------------------------------------------------------------------------
463 static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *norm) {
464   Ceed              ceed;
465   CeedSize          length, num_calls;
466   const CeedScalar *d_array;
467   CeedVector_Hip   *impl;
468   hipblasHandle_t   handle;
469 
470   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
471   CeedCallBackend(CeedVectorGetData(vec, &impl));
472   CeedCallBackend(CeedVectorGetLength(vec, &length));
473   CeedCallBackend(CeedGetHipblasHandle_Hip(ceed, &handle));
474 
475   // Is the vector too long to handle with int32? If so, we will divide
476   // it up into "int32-sized" subsections and make repeated BLAS calls.
477   num_calls = length / INT_MAX;
478   if (length % INT_MAX > 0) num_calls += 1;
479 
480   // Compute norm
481   CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array));
482   switch (type) {
483     case CEED_NORM_1: {
484       *norm = 0.0;
485       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
486         float  sub_norm = 0.0;
487         float *d_array_start;
488 
489         for (CeedInt i = 0; i < num_calls; i++) {
490           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
491           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
492           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
493 
494           CeedCallHipblas(ceed, hipblasSasum(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
495           *norm += sub_norm;
496         }
497       } else {
498         double  sub_norm = 0.0;
499         double *d_array_start;
500 
501         for (CeedInt i = 0; i < num_calls; i++) {
502           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
503           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
504           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
505 
506           CeedCallHipblas(ceed, hipblasDasum(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
507           *norm += sub_norm;
508         }
509       }
510       break;
511     }
512     case CEED_NORM_2: {
513       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
514         float  sub_norm = 0.0, norm_sum = 0.0;
515         float *d_array_start;
516 
517         for (CeedInt i = 0; i < num_calls; i++) {
518           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
519           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
520           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
521 
522           CeedCallHipblas(ceed, hipblasSnrm2(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
523           norm_sum += sub_norm * sub_norm;
524         }
525         *norm = sqrt(norm_sum);
526       } else {
527         double  sub_norm = 0.0, norm_sum = 0.0;
528         double *d_array_start;
529 
530         for (CeedInt i = 0; i < num_calls; i++) {
531           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
532           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
533           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
534 
535           CeedCallHipblas(ceed, hipblasDnrm2(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
536           norm_sum += sub_norm * sub_norm;
537         }
538         *norm = sqrt(norm_sum);
539       }
540       break;
541     }
542     case CEED_NORM_MAX: {
543       CeedInt index;
544 
545       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
546         float  sub_max = 0.0, current_max = 0.0;
547         float *d_array_start;
548         for (CeedInt i = 0; i < num_calls; i++) {
549           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
550           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
551           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
552 
553           CeedCallHipblas(ceed, hipblasIsamax(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &index));
554           CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
555           if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
556         }
557         *norm = current_max;
558       } else {
559         double  sub_max = 0.0, current_max = 0.0;
560         double *d_array_start;
561 
562         for (CeedInt i = 0; i < num_calls; i++) {
563           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
564           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
565           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
566 
567           CeedCallHipblas(ceed, hipblasIdamax(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &index));
568           CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
569           if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
570         }
571         *norm = current_max;
572       }
573       break;
574     }
575   }
576   CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array));
577   return CEED_ERROR_SUCCESS;
578 }
579 
580 //------------------------------------------------------------------------------
581 // Take reciprocal of a vector on host
582 //------------------------------------------------------------------------------
583 static int CeedHostReciprocal_Hip(CeedScalar *h_array, CeedSize length) {
584   for (CeedSize i = 0; i < length; i++) {
585     if (fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i];
586   }
587   return CEED_ERROR_SUCCESS;
588 }
589 
590 //------------------------------------------------------------------------------
591 // Take reciprocal of a vector on device (impl in .hip.cpp file)
592 //------------------------------------------------------------------------------
593 int CeedDeviceReciprocal_Hip(CeedScalar *d_array, CeedSize length);
594 
595 //------------------------------------------------------------------------------
596 // Take reciprocal of a vector
597 //------------------------------------------------------------------------------
598 static int CeedVectorReciprocal_Hip(CeedVector vec) {
599   CeedSize        length;
600   CeedVector_Hip *impl;
601 
602   CeedCallBackend(CeedVectorGetData(vec, &impl));
603   CeedCallBackend(CeedVectorGetLength(vec, &length));
604   // Set value for synced device/host array
605   if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Hip(impl->d_array, length));
606   if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Hip(impl->h_array, length));
607   return CEED_ERROR_SUCCESS;
608 }
609 
610 //------------------------------------------------------------------------------
611 // Compute x = alpha x on the host
612 //------------------------------------------------------------------------------
613 static int CeedHostScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length) {
614   for (CeedSize i = 0; i < length; i++) x_array[i] *= alpha;
615   return CEED_ERROR_SUCCESS;
616 }
617 
618 //------------------------------------------------------------------------------
619 // Compute x = alpha x on device (impl in .hip.cpp file)
620 //------------------------------------------------------------------------------
621 int CeedDeviceScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length);
622 
623 //------------------------------------------------------------------------------
624 // Compute x = alpha x
625 //------------------------------------------------------------------------------
626 static int CeedVectorScale_Hip(CeedVector x, CeedScalar alpha) {
627   CeedSize        length;
628   CeedVector_Hip *x_impl;
629 
630   CeedCallBackend(CeedVectorGetData(x, &x_impl));
631   CeedCallBackend(CeedVectorGetLength(x, &length));
632   // Set value for synced device/host array
633   if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Hip(x_impl->d_array, alpha, length));
634   if (x_impl->h_array) CeedCallBackend(CeedHostScale_Hip(x_impl->h_array, alpha, length));
635   return CEED_ERROR_SUCCESS;
636 }
637 
638 //------------------------------------------------------------------------------
639 // Compute y = alpha x + y on the host
640 //------------------------------------------------------------------------------
641 static int CeedHostAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) {
642   for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i];
643   return CEED_ERROR_SUCCESS;
644 }
645 
646 //------------------------------------------------------------------------------
647 // Compute y = alpha x + y on device (impl in .hip.cpp file)
648 //------------------------------------------------------------------------------
649 int CeedDeviceAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length);
650 
651 //------------------------------------------------------------------------------
652 // Compute y = alpha x + y
653 //------------------------------------------------------------------------------
654 static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) {
655   CeedSize        length;
656   CeedVector_Hip *y_impl, *x_impl;
657 
658   CeedCallBackend(CeedVectorGetData(y, &y_impl));
659   CeedCallBackend(CeedVectorGetData(x, &x_impl));
660   CeedCallBackend(CeedVectorGetLength(y, &length));
661   // Set value for synced device/host array
662   if (y_impl->d_array) {
663     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
664     CeedCallBackend(CeedDeviceAXPY_Hip(y_impl->d_array, alpha, x_impl->d_array, length));
665   }
666   if (y_impl->h_array) {
667     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
668     CeedCallBackend(CeedHostAXPY_Hip(y_impl->h_array, alpha, x_impl->h_array, length));
669   }
670   return CEED_ERROR_SUCCESS;
671 }
672 
673 //------------------------------------------------------------------------------
674 // Compute y = alpha x + beta y on the host
675 //------------------------------------------------------------------------------
676 static int CeedHostAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length) {
677   for (CeedSize i = 0; i < length; i++) y_array[i] = alpha * x_array[i] + beta * y_array[i];
678   return CEED_ERROR_SUCCESS;
679 }
680 
681 //------------------------------------------------------------------------------
682 // Compute y = alpha x + beta y on device (impl in .hip.cpp file)
683 //------------------------------------------------------------------------------
684 int CeedDeviceAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length);
685 
686 //------------------------------------------------------------------------------
687 // Compute y = alpha x + beta y
688 //------------------------------------------------------------------------------
689 static int CeedVectorAXPBY_Hip(CeedVector y, CeedScalar alpha, CeedScalar beta, CeedVector x) {
690   CeedSize        length;
691   CeedVector_Hip *y_impl, *x_impl;
692 
693   CeedCallBackend(CeedVectorGetData(y, &y_impl));
694   CeedCallBackend(CeedVectorGetData(x, &x_impl));
695   CeedCallBackend(CeedVectorGetLength(y, &length));
696   // Set value for synced device/host array
697   if (y_impl->d_array) {
698     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
699     CeedCallBackend(CeedDeviceAXPBY_Hip(y_impl->d_array, alpha, beta, x_impl->d_array, length));
700   }
701   if (y_impl->h_array) {
702     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
703     CeedCallBackend(CeedHostAXPBY_Hip(y_impl->h_array, alpha, beta, x_impl->h_array, length));
704   }
705   return CEED_ERROR_SUCCESS;
706 }
707 
708 //------------------------------------------------------------------------------
709 // Compute the pointwise multiplication w = x .* y on the host
710 //------------------------------------------------------------------------------
711 static int CeedHostPointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) {
712   for (CeedSize i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i];
713   return CEED_ERROR_SUCCESS;
714 }
715 
716 //------------------------------------------------------------------------------
717 // Compute the pointwise multiplication w = x .* y on device (impl in .hip.cpp file)
718 //------------------------------------------------------------------------------
719 int CeedDevicePointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length);
720 
721 //------------------------------------------------------------------------------
722 // Compute the pointwise multiplication w = x .* y
723 //------------------------------------------------------------------------------
724 static int CeedVectorPointwiseMult_Hip(CeedVector w, CeedVector x, CeedVector y) {
725   CeedSize        length;
726   CeedVector_Hip *w_impl, *x_impl, *y_impl;
727 
728   CeedCallBackend(CeedVectorGetData(w, &w_impl));
729   CeedCallBackend(CeedVectorGetData(x, &x_impl));
730   CeedCallBackend(CeedVectorGetData(y, &y_impl));
731   CeedCallBackend(CeedVectorGetLength(w, &length));
732 
733   // Set value for synced device/host array
734   if (!w_impl->d_array && !w_impl->h_array) {
735     CeedCallBackend(CeedVectorSetValue(w, 0.0));
736   }
737   if (w_impl->d_array) {
738     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
739     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE));
740     CeedCallBackend(CeedDevicePointwiseMult_Hip(w_impl->d_array, x_impl->d_array, y_impl->d_array, length));
741   }
742   if (w_impl->h_array) {
743     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
744     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST));
745     CeedCallBackend(CeedHostPointwiseMult_Hip(w_impl->h_array, x_impl->h_array, y_impl->h_array, length));
746   }
747   return CEED_ERROR_SUCCESS;
748 }
749 
750 //------------------------------------------------------------------------------
751 // Destroy the vector
752 //------------------------------------------------------------------------------
753 static int CeedVectorDestroy_Hip(const CeedVector vec) {
754   CeedVector_Hip *impl;
755 
756   CeedCallBackend(CeedVectorGetData(vec, &impl));
757   CeedCallHip(CeedVectorReturnCeed(vec), hipFree(impl->d_array_owned));
758   CeedCallBackend(CeedFree(&impl->h_array_owned));
759   CeedCallBackend(CeedFree(&impl));
760   return CEED_ERROR_SUCCESS;
761 }
762 
763 //------------------------------------------------------------------------------
764 // Create a vector of the specified length (does not allocate memory)
765 //------------------------------------------------------------------------------
766 int CeedVectorCreate_Hip(CeedSize n, CeedVector vec) {
767   CeedVector_Hip *impl;
768   Ceed            ceed;
769 
770   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
771   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Hip));
772   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Hip));
773   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Hip));
774   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Hip));
775   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "CopyStrided", CeedVectorCopyStrided_Hip));
776   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", CeedVectorSetValue_Hip));
777   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValueStrided", CeedVectorSetValueStrided_Hip));
778   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Hip));
779   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Hip));
780   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Hip));
781   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Hip));
782   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Norm", CeedVectorNorm_Hip));
783   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Hip));
784   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Scale", CeedVectorScale_Hip));
785   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPY", CeedVectorAXPY_Hip));
786   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPBY", CeedVectorAXPBY_Hip));
787   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Hip));
788   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Hip));
789   CeedCallBackend(CeedCalloc(1, &impl));
790   CeedCallBackend(CeedVectorSetData(vec, impl));
791   return CEED_ERROR_SUCCESS;
792 }
793 
794 //------------------------------------------------------------------------------
795