xref: /libCEED/backends/hip-ref/ceed-hip-ref-vector.c (revision e3ae47f6083a97dd785ae32cf6cad3c05e295e6c)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <math.h>
11 #include <stdbool.h>
12 #include <string.h>
13 #include <hip/hip_runtime.h>
14 
15 #include "../hip/ceed-hip-common.h"
16 #include "ceed-hip-ref.h"
17 
18 //------------------------------------------------------------------------------
19 // Check if host/device sync is needed
20 //------------------------------------------------------------------------------
21 static inline int CeedVectorNeedSync_Hip(const CeedVector vec, CeedMemType mem_type, bool *need_sync) {
22   CeedVector_Hip *impl;
23   bool            has_valid_array = false;
24 
25   CeedCallBackend(CeedVectorGetData(vec, &impl));
26   CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array));
27   switch (mem_type) {
28     case CEED_MEM_HOST:
29       *need_sync = has_valid_array && !impl->h_array;
30       break;
31     case CEED_MEM_DEVICE:
32       *need_sync = has_valid_array && !impl->d_array;
33       break;
34   }
35   return CEED_ERROR_SUCCESS;
36 }
37 
38 //------------------------------------------------------------------------------
39 // Sync host to device
40 //------------------------------------------------------------------------------
41 static inline int CeedVectorSyncH2D_Hip(const CeedVector vec) {
42   CeedSize        length;
43   size_t          bytes;
44   CeedVector_Hip *impl;
45 
46   CeedCallBackend(CeedVectorGetData(vec, &impl));
47 
48   CeedCheck(impl->h_array, CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND, "No valid host data to sync to device");
49 
50   CeedCallBackend(CeedVectorGetLength(vec, &length));
51   bytes = length * sizeof(CeedScalar);
52   if (impl->d_array_borrowed) {
53     impl->d_array = impl->d_array_borrowed;
54   } else if (impl->d_array_owned) {
55     impl->d_array = impl->d_array_owned;
56   } else {
57     CeedCallHip(CeedVectorReturnCeed(vec), hipMalloc((void **)&impl->d_array_owned, bytes));
58     impl->d_array = impl->d_array_owned;
59   }
60   CeedCallHip(CeedVectorReturnCeed(vec), hipMemcpy(impl->d_array, impl->h_array, bytes, hipMemcpyHostToDevice));
61   return CEED_ERROR_SUCCESS;
62 }
63 
64 //------------------------------------------------------------------------------
65 // Sync device to host
66 //------------------------------------------------------------------------------
67 static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) {
68   CeedSize        length;
69   size_t          bytes;
70   CeedVector_Hip *impl;
71 
72   CeedCallBackend(CeedVectorGetData(vec, &impl));
73 
74   CeedCheck(impl->d_array, CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND, "No valid device data to sync to host");
75 
76   if (impl->h_array_borrowed) {
77     impl->h_array = impl->h_array_borrowed;
78   } else if (impl->h_array_owned) {
79     impl->h_array = impl->h_array_owned;
80   } else {
81     CeedSize length;
82 
83     CeedCallBackend(CeedVectorGetLength(vec, &length));
84     CeedCallBackend(CeedCalloc(length, &impl->h_array_owned));
85     impl->h_array = impl->h_array_owned;
86   }
87 
88   CeedCallBackend(CeedVectorGetLength(vec, &length));
89   bytes = length * sizeof(CeedScalar);
90   CeedCallHip(CeedVectorReturnCeed(vec), hipMemcpy(impl->h_array, impl->d_array, bytes, hipMemcpyDeviceToHost));
91   return CEED_ERROR_SUCCESS;
92 }
93 
94 //------------------------------------------------------------------------------
95 // Sync arrays
96 //------------------------------------------------------------------------------
97 static int CeedVectorSyncArray_Hip(const CeedVector vec, CeedMemType mem_type) {
98   bool need_sync = false;
99 
100   // Check whether device/host sync is needed
101   CeedCallBackend(CeedVectorNeedSync_Hip(vec, mem_type, &need_sync));
102   if (!need_sync) return CEED_ERROR_SUCCESS;
103 
104   switch (mem_type) {
105     case CEED_MEM_HOST:
106       return CeedVectorSyncD2H_Hip(vec);
107     case CEED_MEM_DEVICE:
108       return CeedVectorSyncH2D_Hip(vec);
109   }
110   return CEED_ERROR_UNSUPPORTED;
111 }
112 
113 //------------------------------------------------------------------------------
114 // Set all pointers as invalid
115 //------------------------------------------------------------------------------
116 static inline int CeedVectorSetAllInvalid_Hip(const CeedVector vec) {
117   CeedVector_Hip *impl;
118 
119   CeedCallBackend(CeedVectorGetData(vec, &impl));
120   impl->h_array = NULL;
121   impl->d_array = NULL;
122   return CEED_ERROR_SUCCESS;
123 }
124 
125 //------------------------------------------------------------------------------
126 // Check if CeedVector has any valid pointer
127 //------------------------------------------------------------------------------
128 static inline int CeedVectorHasValidArray_Hip(const CeedVector vec, bool *has_valid_array) {
129   CeedVector_Hip *impl;
130 
131   CeedCallBackend(CeedVectorGetData(vec, &impl));
132   *has_valid_array = impl->h_array || impl->d_array;
133   return CEED_ERROR_SUCCESS;
134 }
135 
136 //------------------------------------------------------------------------------
137 // Check if has array of given type
138 //------------------------------------------------------------------------------
139 static inline int CeedVectorHasArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) {
140   CeedVector_Hip *impl;
141 
142   CeedCallBackend(CeedVectorGetData(vec, &impl));
143   switch (mem_type) {
144     case CEED_MEM_HOST:
145       *has_array_of_type = impl->h_array_borrowed || impl->h_array_owned;
146       break;
147     case CEED_MEM_DEVICE:
148       *has_array_of_type = impl->d_array_borrowed || impl->d_array_owned;
149       break;
150   }
151   return CEED_ERROR_SUCCESS;
152 }
153 
154 //------------------------------------------------------------------------------
155 // Check if has borrowed array of given type
156 //------------------------------------------------------------------------------
157 static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) {
158   CeedVector_Hip *impl;
159 
160   CeedCallBackend(CeedVectorGetData(vec, &impl));
161   switch (mem_type) {
162     case CEED_MEM_HOST:
163       *has_borrowed_array_of_type = impl->h_array_borrowed;
164       break;
165     case CEED_MEM_DEVICE:
166       *has_borrowed_array_of_type = impl->d_array_borrowed;
167       break;
168   }
169   return CEED_ERROR_SUCCESS;
170 }
171 
172 //------------------------------------------------------------------------------
173 // Set array from host
174 //------------------------------------------------------------------------------
175 static int CeedVectorSetArrayHost_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
176   CeedSize        length;
177   CeedVector_Hip *impl;
178 
179   CeedCallBackend(CeedVectorGetData(vec, &impl));
180   CeedCallBackend(CeedVectorGetLength(vec, &length));
181 
182   CeedCallBackend(CeedSetHostCeedScalarArray(array, copy_mode, length, (const CeedScalar **)&impl->h_array_owned,
183                                              (const CeedScalar **)&impl->h_array_borrowed, (const CeedScalar **)&impl->h_array));
184   return CEED_ERROR_SUCCESS;
185 }
186 
187 //------------------------------------------------------------------------------
188 // Set array from device
189 //------------------------------------------------------------------------------
190 static int CeedVectorSetArrayDevice_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
191   CeedSize        length;
192   Ceed            ceed;
193   CeedVector_Hip *impl;
194 
195   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
196   CeedCallBackend(CeedVectorGetData(vec, &impl));
197   CeedCallBackend(CeedVectorGetLength(vec, &length));
198 
199   CeedCallBackend(CeedSetDeviceCeedScalarArray_Hip(ceed, array, copy_mode, length, (const CeedScalar **)&impl->d_array_owned,
200                                                    (const CeedScalar **)&impl->d_array_borrowed, (const CeedScalar **)&impl->d_array));
201   CeedCallBackend(CeedDestroy(&ceed));
202   return CEED_ERROR_SUCCESS;
203 }
204 
205 //------------------------------------------------------------------------------
206 // Set the array used by a vector,
207 //   freeing any previously allocated array if applicable
208 //------------------------------------------------------------------------------
209 static int CeedVectorSetArray_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) {
210   CeedVector_Hip *impl;
211 
212   CeedCallBackend(CeedVectorGetData(vec, &impl));
213   CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
214   switch (mem_type) {
215     case CEED_MEM_HOST:
216       return CeedVectorSetArrayHost_Hip(vec, copy_mode, array);
217     case CEED_MEM_DEVICE:
218       return CeedVectorSetArrayDevice_Hip(vec, copy_mode, array);
219   }
220   return CEED_ERROR_UNSUPPORTED;
221 }
222 
223 //------------------------------------------------------------------------------
224 // Copy host array to value strided
225 //------------------------------------------------------------------------------
226 static int CeedHostCopyStrided_Hip(CeedScalar *h_array, CeedSize start, CeedSize step, CeedSize length, CeedScalar *h_copy_array) {
227   for (CeedSize i = start; i < length; i += step) h_copy_array[i] = h_array[i];
228   return CEED_ERROR_SUCCESS;
229 }
230 
231 //------------------------------------------------------------------------------
232 // Copy device array to value strided (impl in .hip.cpp file)
233 //------------------------------------------------------------------------------
234 int CeedDeviceCopyStrided_Hip(CeedScalar *d_array, CeedSize start, CeedSize step, CeedSize length, CeedScalar *d_copy_array);
235 
236 //------------------------------------------------------------------------------
237 // Copy a vector to a value strided
238 //------------------------------------------------------------------------------
239 static int CeedVectorCopyStrided_Hip(CeedVector vec, CeedSize start, CeedSize step, CeedVector vec_copy) {
240   CeedSize        length;
241   CeedVector_Hip *impl;
242 
243   CeedCallBackend(CeedVectorGetData(vec, &impl));
244   {
245     CeedSize length_vec, length_copy;
246 
247     CeedCallBackend(CeedVectorGetLength(vec, &length_vec));
248     CeedCallBackend(CeedVectorGetLength(vec_copy, &length_copy));
249     length = length_vec < length_copy ? length_vec : length_copy;
250   }
251   // Set value for synced device/host array
252   if (impl->d_array) {
253     CeedScalar *copy_array;
254 
255     CeedCallBackend(CeedVectorGetArray(vec_copy, CEED_MEM_DEVICE, &copy_array));
256     CeedCallBackend(CeedDeviceCopyStrided_Hip(impl->d_array, start, step, length, copy_array));
257     CeedCallBackend(CeedVectorRestoreArray(vec_copy, &copy_array));
258   } else if (impl->h_array) {
259     CeedScalar *copy_array;
260 
261     CeedCallBackend(CeedVectorGetArray(vec_copy, CEED_MEM_HOST, &copy_array));
262     CeedCallBackend(CeedHostCopyStrided_Hip(impl->h_array, start, step, length, copy_array));
263     CeedCallBackend(CeedVectorRestoreArray(vec_copy, &copy_array));
264   } else {
265     return CeedError(CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND, "CeedVector must have valid data set");
266   }
267   return CEED_ERROR_SUCCESS;
268 }
269 
270 //------------------------------------------------------------------------------
271 // Set host array to value
272 //------------------------------------------------------------------------------
273 static int CeedHostSetValue_Hip(CeedScalar *h_array, CeedSize length, CeedScalar val) {
274   for (CeedSize i = 0; i < length; i++) h_array[i] = val;
275   return CEED_ERROR_SUCCESS;
276 }
277 
278 //------------------------------------------------------------------------------
279 // Set device array to value (impl in .hip.cpp file)
280 //------------------------------------------------------------------------------
281 int CeedDeviceSetValue_Hip(CeedScalar *d_array, CeedSize length, CeedScalar val);
282 
283 //------------------------------------------------------------------------------
284 // Set a vector to a value
285 //------------------------------------------------------------------------------
286 static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) {
287   CeedSize        length;
288   CeedVector_Hip *impl;
289 
290   CeedCallBackend(CeedVectorGetData(vec, &impl));
291   CeedCallBackend(CeedVectorGetLength(vec, &length));
292   // Set value for synced device/host array
293   if (!impl->d_array && !impl->h_array) {
294     if (impl->d_array_borrowed) {
295       impl->d_array = impl->d_array_borrowed;
296     } else if (impl->h_array_borrowed) {
297       impl->h_array = impl->h_array_borrowed;
298     } else if (impl->d_array_owned) {
299       impl->d_array = impl->d_array_owned;
300     } else if (impl->h_array_owned) {
301       impl->h_array = impl->h_array_owned;
302     } else {
303       CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL));
304     }
305   }
306   if (impl->d_array) {
307     CeedCallBackend(CeedDeviceSetValue_Hip(impl->d_array, length, val));
308     impl->h_array = NULL;
309   }
310   if (impl->h_array) {
311     CeedCallBackend(CeedHostSetValue_Hip(impl->h_array, length, val));
312     impl->d_array = NULL;
313   }
314   return CEED_ERROR_SUCCESS;
315 }
316 
317 //------------------------------------------------------------------------------
318 // Set host array to value strided
319 //------------------------------------------------------------------------------
320 static int CeedHostSetValueStrided_Hip(CeedScalar *h_array, CeedSize start, CeedSize step, CeedSize length, CeedScalar val) {
321   for (CeedSize i = start; i < length; i += step) h_array[i] = val;
322   return CEED_ERROR_SUCCESS;
323 }
324 
325 //------------------------------------------------------------------------------
326 // Set device array to value strided (impl in .hip.cpp file)
327 //------------------------------------------------------------------------------
328 int CeedDeviceSetValueStrided_Hip(CeedScalar *d_array, CeedSize start, CeedSize step, CeedSize length, CeedScalar val);
329 
330 //------------------------------------------------------------------------------
331 // Set a vector to a value strided
332 //------------------------------------------------------------------------------
333 static int CeedVectorSetValueStrided_Hip(CeedVector vec, CeedSize start, CeedSize step, CeedScalar val) {
334   CeedSize        length;
335   CeedVector_Hip *impl;
336 
337   CeedCallBackend(CeedVectorGetData(vec, &impl));
338   CeedCallBackend(CeedVectorGetLength(vec, &length));
339   // Set value for synced device/host array
340   if (impl->d_array) {
341     CeedCallBackend(CeedDeviceSetValueStrided_Hip(impl->d_array, start, step, length, val));
342     impl->h_array = NULL;
343   } else if (impl->h_array) {
344     CeedCallBackend(CeedHostSetValueStrided_Hip(impl->h_array, start, step, length, val));
345     impl->d_array = NULL;
346   } else {
347     return CeedError(CeedVectorReturnCeed(vec), CEED_ERROR_BACKEND, "CeedVector must have valid data set");
348   }
349   return CEED_ERROR_SUCCESS;
350 }
351 
352 //------------------------------------------------------------------------------
353 // Vector Take Array
354 //------------------------------------------------------------------------------
355 static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
356   CeedVector_Hip *impl;
357 
358   CeedCallBackend(CeedVectorGetData(vec, &impl));
359 
360   // Sync array to requested mem_type
361   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
362 
363   // Update pointer
364   switch (mem_type) {
365     case CEED_MEM_HOST:
366       (*array)               = impl->h_array_borrowed;
367       impl->h_array_borrowed = NULL;
368       impl->h_array          = NULL;
369       break;
370     case CEED_MEM_DEVICE:
371       (*array)               = impl->d_array_borrowed;
372       impl->d_array_borrowed = NULL;
373       impl->d_array          = NULL;
374       break;
375   }
376   return CEED_ERROR_SUCCESS;
377 }
378 
379 //------------------------------------------------------------------------------
380 // Core logic for array synchronization for GetArray.
381 //   If a different memory type is most up to date, this will perform a copy
382 //------------------------------------------------------------------------------
383 static int CeedVectorGetArrayCore_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
384   CeedVector_Hip *impl;
385 
386   CeedCallBackend(CeedVectorGetData(vec, &impl));
387 
388   // Sync array to requested mem_type
389   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
390 
391   // Update pointer
392   switch (mem_type) {
393     case CEED_MEM_HOST:
394       *array = impl->h_array;
395       break;
396     case CEED_MEM_DEVICE:
397       *array = impl->d_array;
398       break;
399   }
400   return CEED_ERROR_SUCCESS;
401 }
402 
403 //------------------------------------------------------------------------------
404 // Get read-only access to a vector via the specified mem_type
405 //------------------------------------------------------------------------------
406 static int CeedVectorGetArrayRead_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) {
407   return CeedVectorGetArrayCore_Hip(vec, mem_type, (CeedScalar **)array);
408 }
409 
410 //------------------------------------------------------------------------------
411 // Get read/write access to a vector via the specified mem_type
412 //------------------------------------------------------------------------------
413 static int CeedVectorGetArray_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
414   CeedVector_Hip *impl;
415 
416   CeedCallBackend(CeedVectorGetData(vec, &impl));
417   CeedCallBackend(CeedVectorGetArrayCore_Hip(vec, mem_type, array));
418   CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
419   switch (mem_type) {
420     case CEED_MEM_HOST:
421       impl->h_array = *array;
422       break;
423     case CEED_MEM_DEVICE:
424       impl->d_array = *array;
425       break;
426   }
427   return CEED_ERROR_SUCCESS;
428 }
429 
430 //------------------------------------------------------------------------------
431 // Get write access to a vector via the specified mem_type
432 //------------------------------------------------------------------------------
433 static int CeedVectorGetArrayWrite_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
434   bool            has_array_of_type = true;
435   CeedVector_Hip *impl;
436 
437   CeedCallBackend(CeedVectorGetData(vec, &impl));
438   CeedCallBackend(CeedVectorHasArrayOfType_Hip(vec, mem_type, &has_array_of_type));
439   if (!has_array_of_type) {
440     // Allocate if array is not yet allocated
441     CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL));
442   } else {
443     // Select dirty array
444     switch (mem_type) {
445       case CEED_MEM_HOST:
446         if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed;
447         else impl->h_array = impl->h_array_owned;
448         break;
449       case CEED_MEM_DEVICE:
450         if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed;
451         else impl->d_array = impl->d_array_owned;
452     }
453   }
454   return CeedVectorGetArray_Hip(vec, mem_type, array);
455 }
456 
457 //------------------------------------------------------------------------------
458 // Get the norm of a CeedVector
459 //------------------------------------------------------------------------------
460 static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *norm) {
461   Ceed              ceed;
462   CeedSize          length, num_calls;
463   const CeedScalar *d_array;
464   CeedVector_Hip   *impl;
465   hipblasHandle_t   handle;
466 
467   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
468   CeedCallBackend(CeedVectorGetData(vec, &impl));
469   CeedCallBackend(CeedVectorGetLength(vec, &length));
470   CeedCallBackend(CeedGetHipblasHandle_Hip(ceed, &handle));
471 
472   // Is the vector too long to handle with int32? If so, we will divide
473   // it up into "int32-sized" subsections and make repeated BLAS calls.
474   num_calls = length / INT_MAX;
475   if (length % INT_MAX > 0) num_calls += 1;
476 
477   // Compute norm
478   CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array));
479   switch (type) {
480     case CEED_NORM_1: {
481       *norm = 0.0;
482       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
483         float  sub_norm = 0.0;
484         float *d_array_start;
485 
486         for (CeedInt i = 0; i < num_calls; i++) {
487           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
488           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
489           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
490 
491           CeedCallHipblas(ceed, hipblasSasum(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
492           *norm += sub_norm;
493         }
494       } else {
495         double  sub_norm = 0.0;
496         double *d_array_start;
497 
498         for (CeedInt i = 0; i < num_calls; i++) {
499           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
500           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
501           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
502 
503           CeedCallHipblas(ceed, hipblasDasum(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
504           *norm += sub_norm;
505         }
506       }
507       break;
508     }
509     case CEED_NORM_2: {
510       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
511         float  sub_norm = 0.0, norm_sum = 0.0;
512         float *d_array_start;
513 
514         for (CeedInt i = 0; i < num_calls; i++) {
515           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
516           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
517           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
518 
519           CeedCallHipblas(ceed, hipblasSnrm2(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
520           norm_sum += sub_norm * sub_norm;
521         }
522         *norm = sqrt(norm_sum);
523       } else {
524         double  sub_norm = 0.0, norm_sum = 0.0;
525         double *d_array_start;
526 
527         for (CeedInt i = 0; i < num_calls; i++) {
528           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
529           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
530           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
531 
532           CeedCallHipblas(ceed, hipblasDnrm2(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
533           norm_sum += sub_norm * sub_norm;
534         }
535         *norm = sqrt(norm_sum);
536       }
537       break;
538     }
539     case CEED_NORM_MAX: {
540       CeedInt index;
541 
542       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
543         float  sub_max = 0.0, current_max = 0.0;
544         float *d_array_start;
545         for (CeedInt i = 0; i < num_calls; i++) {
546           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
547           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
548           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
549 
550           CeedCallHipblas(ceed, hipblasIsamax(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &index));
551           CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
552           if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
553         }
554         *norm = current_max;
555       } else {
556         double  sub_max = 0.0, current_max = 0.0;
557         double *d_array_start;
558 
559         for (CeedInt i = 0; i < num_calls; i++) {
560           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
561           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
562           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
563 
564           CeedCallHipblas(ceed, hipblasIdamax(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &index));
565           CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
566           if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
567         }
568         *norm = current_max;
569       }
570       break;
571     }
572   }
573   CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array));
574   CeedCallBackend(CeedDestroy(&ceed));
575   return CEED_ERROR_SUCCESS;
576 }
577 
578 //------------------------------------------------------------------------------
579 // Take reciprocal of a vector on host
580 //------------------------------------------------------------------------------
581 static int CeedHostReciprocal_Hip(CeedScalar *h_array, CeedSize length) {
582   for (CeedSize i = 0; i < length; i++) {
583     if (fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i];
584   }
585   return CEED_ERROR_SUCCESS;
586 }
587 
588 //------------------------------------------------------------------------------
589 // Take reciprocal of a vector on device (impl in .hip.cpp file)
590 //------------------------------------------------------------------------------
591 int CeedDeviceReciprocal_Hip(CeedScalar *d_array, CeedSize length);
592 
593 //------------------------------------------------------------------------------
594 // Take reciprocal of a vector
595 //------------------------------------------------------------------------------
596 static int CeedVectorReciprocal_Hip(CeedVector vec) {
597   CeedSize        length;
598   CeedVector_Hip *impl;
599 
600   CeedCallBackend(CeedVectorGetData(vec, &impl));
601   CeedCallBackend(CeedVectorGetLength(vec, &length));
602   // Set value for synced device/host array
603   if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Hip(impl->d_array, length));
604   if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Hip(impl->h_array, length));
605   return CEED_ERROR_SUCCESS;
606 }
607 
608 //------------------------------------------------------------------------------
609 // Compute x = alpha x on the host
610 //------------------------------------------------------------------------------
611 static int CeedHostScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length) {
612   for (CeedSize i = 0; i < length; i++) x_array[i] *= alpha;
613   return CEED_ERROR_SUCCESS;
614 }
615 
616 //------------------------------------------------------------------------------
617 // Compute x = alpha x on device (impl in .hip.cpp file)
618 //------------------------------------------------------------------------------
619 int CeedDeviceScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length);
620 
621 //------------------------------------------------------------------------------
622 // Compute x = alpha x
623 //------------------------------------------------------------------------------
624 static int CeedVectorScale_Hip(CeedVector x, CeedScalar alpha) {
625   CeedSize        length;
626   CeedVector_Hip *x_impl;
627 
628   CeedCallBackend(CeedVectorGetData(x, &x_impl));
629   CeedCallBackend(CeedVectorGetLength(x, &length));
630   // Set value for synced device/host array
631   if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Hip(x_impl->d_array, alpha, length));
632   if (x_impl->h_array) CeedCallBackend(CeedHostScale_Hip(x_impl->h_array, alpha, length));
633   return CEED_ERROR_SUCCESS;
634 }
635 
636 //------------------------------------------------------------------------------
637 // Compute y = alpha x + y on the host
638 //------------------------------------------------------------------------------
639 static int CeedHostAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) {
640   for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i];
641   return CEED_ERROR_SUCCESS;
642 }
643 
644 //------------------------------------------------------------------------------
645 // Compute y = alpha x + y on device (impl in .hip.cpp file)
646 //------------------------------------------------------------------------------
647 int CeedDeviceAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length);
648 
649 //------------------------------------------------------------------------------
650 // Compute y = alpha x + y
651 //------------------------------------------------------------------------------
652 static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) {
653   CeedSize        length;
654   CeedVector_Hip *y_impl, *x_impl;
655 
656   CeedCallBackend(CeedVectorGetData(y, &y_impl));
657   CeedCallBackend(CeedVectorGetData(x, &x_impl));
658   CeedCallBackend(CeedVectorGetLength(y, &length));
659   // Set value for synced device/host array
660   if (y_impl->d_array) {
661     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
662     CeedCallBackend(CeedDeviceAXPY_Hip(y_impl->d_array, alpha, x_impl->d_array, length));
663   }
664   if (y_impl->h_array) {
665     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
666     CeedCallBackend(CeedHostAXPY_Hip(y_impl->h_array, alpha, x_impl->h_array, length));
667   }
668   return CEED_ERROR_SUCCESS;
669 }
670 
671 //------------------------------------------------------------------------------
672 // Compute y = alpha x + beta y on the host
673 //------------------------------------------------------------------------------
674 static int CeedHostAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length) {
675   for (CeedSize i = 0; i < length; i++) y_array[i] = alpha * x_array[i] + beta * y_array[i];
676   return CEED_ERROR_SUCCESS;
677 }
678 
679 //------------------------------------------------------------------------------
680 // Compute y = alpha x + beta y on device (impl in .hip.cpp file)
681 //------------------------------------------------------------------------------
682 int CeedDeviceAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length);
683 
684 //------------------------------------------------------------------------------
685 // Compute y = alpha x + beta y
686 //------------------------------------------------------------------------------
687 static int CeedVectorAXPBY_Hip(CeedVector y, CeedScalar alpha, CeedScalar beta, CeedVector x) {
688   CeedSize        length;
689   CeedVector_Hip *y_impl, *x_impl;
690 
691   CeedCallBackend(CeedVectorGetData(y, &y_impl));
692   CeedCallBackend(CeedVectorGetData(x, &x_impl));
693   CeedCallBackend(CeedVectorGetLength(y, &length));
694   // Set value for synced device/host array
695   if (y_impl->d_array) {
696     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
697     CeedCallBackend(CeedDeviceAXPBY_Hip(y_impl->d_array, alpha, beta, x_impl->d_array, length));
698   }
699   if (y_impl->h_array) {
700     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
701     CeedCallBackend(CeedHostAXPBY_Hip(y_impl->h_array, alpha, beta, x_impl->h_array, length));
702   }
703   return CEED_ERROR_SUCCESS;
704 }
705 
706 //------------------------------------------------------------------------------
707 // Compute the pointwise multiplication w = x .* y on the host
708 //------------------------------------------------------------------------------
709 static int CeedHostPointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) {
710   for (CeedSize i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i];
711   return CEED_ERROR_SUCCESS;
712 }
713 
714 //------------------------------------------------------------------------------
715 // Compute the pointwise multiplication w = x .* y on device (impl in .hip.cpp file)
716 //------------------------------------------------------------------------------
717 int CeedDevicePointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length);
718 
719 //------------------------------------------------------------------------------
720 // Compute the pointwise multiplication w = x .* y
721 //------------------------------------------------------------------------------
722 static int CeedVectorPointwiseMult_Hip(CeedVector w, CeedVector x, CeedVector y) {
723   CeedSize        length;
724   CeedVector_Hip *w_impl, *x_impl, *y_impl;
725 
726   CeedCallBackend(CeedVectorGetData(w, &w_impl));
727   CeedCallBackend(CeedVectorGetData(x, &x_impl));
728   CeedCallBackend(CeedVectorGetData(y, &y_impl));
729   CeedCallBackend(CeedVectorGetLength(w, &length));
730 
731   // Set value for synced device/host array
732   if (!w_impl->d_array && !w_impl->h_array) {
733     CeedCallBackend(CeedVectorSetValue(w, 0.0));
734   }
735   if (w_impl->d_array) {
736     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
737     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE));
738     CeedCallBackend(CeedDevicePointwiseMult_Hip(w_impl->d_array, x_impl->d_array, y_impl->d_array, length));
739   }
740   if (w_impl->h_array) {
741     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
742     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST));
743     CeedCallBackend(CeedHostPointwiseMult_Hip(w_impl->h_array, x_impl->h_array, y_impl->h_array, length));
744   }
745   return CEED_ERROR_SUCCESS;
746 }
747 
748 //------------------------------------------------------------------------------
749 // Destroy the vector
750 //------------------------------------------------------------------------------
751 static int CeedVectorDestroy_Hip(const CeedVector vec) {
752   CeedVector_Hip *impl;
753 
754   CeedCallBackend(CeedVectorGetData(vec, &impl));
755   CeedCallHip(CeedVectorReturnCeed(vec), hipFree(impl->d_array_owned));
756   CeedCallBackend(CeedFree(&impl->h_array_owned));
757   CeedCallBackend(CeedFree(&impl));
758   return CEED_ERROR_SUCCESS;
759 }
760 
761 //------------------------------------------------------------------------------
762 // Create a vector of the specified length (does not allocate memory)
763 //------------------------------------------------------------------------------
764 int CeedVectorCreate_Hip(CeedSize n, CeedVector vec) {
765   CeedVector_Hip *impl;
766   Ceed            ceed;
767 
768   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
769   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Hip));
770   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Hip));
771   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Hip));
772   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Hip));
773   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "CopyStrided", CeedVectorCopyStrided_Hip));
774   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", CeedVectorSetValue_Hip));
775   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValueStrided", CeedVectorSetValueStrided_Hip));
776   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Hip));
777   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Hip));
778   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Hip));
779   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Hip));
780   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Norm", CeedVectorNorm_Hip));
781   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Hip));
782   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Scale", CeedVectorScale_Hip));
783   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPY", CeedVectorAXPY_Hip));
784   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPBY", CeedVectorAXPBY_Hip));
785   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Hip));
786   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Hip));
787   CeedCallBackend(CeedDestroy(&ceed));
788   CeedCallBackend(CeedCalloc(1, &impl));
789   CeedCallBackend(CeedVectorSetData(vec, impl));
790   return CEED_ERROR_SUCCESS;
791 }
792 
793 //------------------------------------------------------------------------------
794