xref: /libCEED/backends/hip-ref/ceed-hip-ref-vector.c (revision c769724092b673807cd0e691b9059edb7d377958)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <math.h>
11 #include <stdbool.h>
12 #include <string.h>
13 #include <hip/hip_runtime.h>
14 
15 #include "../hip/ceed-hip-common.h"
16 #include "ceed-hip-ref.h"
17 
18 //------------------------------------------------------------------------------
19 // Check if host/device sync is needed
20 //------------------------------------------------------------------------------
21 static inline int CeedVectorNeedSync_Hip(const CeedVector vec, CeedMemType mem_type, bool *need_sync) {
22   CeedVector_Hip *impl;
23   bool            has_valid_array = false;
24 
25   CeedCallBackend(CeedVectorGetData(vec, &impl));
26   CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array));
27   switch (mem_type) {
28     case CEED_MEM_HOST:
29       *need_sync = has_valid_array && !impl->h_array;
30       break;
31     case CEED_MEM_DEVICE:
32       *need_sync = has_valid_array && !impl->d_array;
33       break;
34   }
35   return CEED_ERROR_SUCCESS;
36 }
37 
38 //------------------------------------------------------------------------------
39 // Sync host to device
40 //------------------------------------------------------------------------------
41 static inline int CeedVectorSyncH2D_Hip(const CeedVector vec) {
42   Ceed            ceed;
43   CeedSize        length;
44   size_t          bytes;
45   CeedVector_Hip *impl;
46 
47   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
48   CeedCallBackend(CeedVectorGetData(vec, &impl));
49 
50   CeedCheck(impl->h_array, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
51 
52   CeedCallBackend(CeedVectorGetLength(vec, &length));
53   bytes = length * sizeof(CeedScalar);
54   if (impl->d_array_borrowed) {
55     impl->d_array = impl->d_array_borrowed;
56   } else if (impl->d_array_owned) {
57     impl->d_array = impl->d_array_owned;
58   } else {
59     CeedCallHip(ceed, hipMalloc((void **)&impl->d_array_owned, bytes));
60     impl->d_array = impl->d_array_owned;
61   }
62   CeedCallHip(ceed, hipMemcpy(impl->d_array, impl->h_array, bytes, hipMemcpyHostToDevice));
63   return CEED_ERROR_SUCCESS;
64 }
65 
66 //------------------------------------------------------------------------------
67 // Sync device to host
68 //------------------------------------------------------------------------------
69 static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) {
70   Ceed            ceed;
71   CeedSize        length;
72   size_t          bytes;
73   CeedVector_Hip *impl;
74 
75   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
76   CeedCallBackend(CeedVectorGetData(vec, &impl));
77 
78   CeedCheck(impl->d_array, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
79 
80   if (impl->h_array_borrowed) {
81     impl->h_array = impl->h_array_borrowed;
82   } else if (impl->h_array_owned) {
83     impl->h_array = impl->h_array_owned;
84   } else {
85     CeedSize length;
86 
87     CeedCallBackend(CeedVectorGetLength(vec, &length));
88     CeedCallBackend(CeedCalloc(length, &impl->h_array_owned));
89     impl->h_array = impl->h_array_owned;
90   }
91 
92   CeedCallBackend(CeedVectorGetLength(vec, &length));
93   bytes = length * sizeof(CeedScalar);
94   CeedCallHip(ceed, hipMemcpy(impl->h_array, impl->d_array, bytes, hipMemcpyDeviceToHost));
95   return CEED_ERROR_SUCCESS;
96 }
97 
98 //------------------------------------------------------------------------------
99 // Sync arrays
100 //------------------------------------------------------------------------------
101 static int CeedVectorSyncArray_Hip(const CeedVector vec, CeedMemType mem_type) {
102   bool need_sync = false;
103 
104   // Check whether device/host sync is needed
105   CeedCallBackend(CeedVectorNeedSync_Hip(vec, mem_type, &need_sync));
106   if (!need_sync) return CEED_ERROR_SUCCESS;
107 
108   switch (mem_type) {
109     case CEED_MEM_HOST:
110       return CeedVectorSyncD2H_Hip(vec);
111     case CEED_MEM_DEVICE:
112       return CeedVectorSyncH2D_Hip(vec);
113   }
114   return CEED_ERROR_UNSUPPORTED;
115 }
116 
117 //------------------------------------------------------------------------------
118 // Set all pointers as invalid
119 //------------------------------------------------------------------------------
120 static inline int CeedVectorSetAllInvalid_Hip(const CeedVector vec) {
121   CeedVector_Hip *impl;
122 
123   CeedCallBackend(CeedVectorGetData(vec, &impl));
124   impl->h_array = NULL;
125   impl->d_array = NULL;
126   return CEED_ERROR_SUCCESS;
127 }
128 
129 //------------------------------------------------------------------------------
130 // Check if CeedVector has any valid pointer
131 //------------------------------------------------------------------------------
132 static inline int CeedVectorHasValidArray_Hip(const CeedVector vec, bool *has_valid_array) {
133   CeedVector_Hip *impl;
134 
135   CeedCallBackend(CeedVectorGetData(vec, &impl));
136   *has_valid_array = impl->h_array || impl->d_array;
137   return CEED_ERROR_SUCCESS;
138 }
139 
140 //------------------------------------------------------------------------------
141 // Check if has array of given type
142 //------------------------------------------------------------------------------
143 static inline int CeedVectorHasArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) {
144   CeedVector_Hip *impl;
145 
146   CeedCallBackend(CeedVectorGetData(vec, &impl));
147   switch (mem_type) {
148     case CEED_MEM_HOST:
149       *has_array_of_type = impl->h_array_borrowed || impl->h_array_owned;
150       break;
151     case CEED_MEM_DEVICE:
152       *has_array_of_type = impl->d_array_borrowed || impl->d_array_owned;
153       break;
154   }
155   return CEED_ERROR_SUCCESS;
156 }
157 
158 //------------------------------------------------------------------------------
159 // Check if has borrowed array of given type
160 //------------------------------------------------------------------------------
161 static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) {
162   CeedVector_Hip *impl;
163 
164   CeedCallBackend(CeedVectorGetData(vec, &impl));
165   switch (mem_type) {
166     case CEED_MEM_HOST:
167       *has_borrowed_array_of_type = impl->h_array_borrowed;
168       break;
169     case CEED_MEM_DEVICE:
170       *has_borrowed_array_of_type = impl->d_array_borrowed;
171       break;
172   }
173   return CEED_ERROR_SUCCESS;
174 }
175 
176 //------------------------------------------------------------------------------
177 // Set array from host
178 //------------------------------------------------------------------------------
179 static int CeedVectorSetArrayHost_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
180   CeedVector_Hip *impl;
181 
182   CeedCallBackend(CeedVectorGetData(vec, &impl));
183   switch (copy_mode) {
184     case CEED_COPY_VALUES: {
185       CeedSize length;
186 
187       if (!impl->h_array_owned) {
188         CeedCallBackend(CeedVectorGetLength(vec, &length));
189         CeedCallBackend(CeedMalloc(length, &impl->h_array_owned));
190       }
191       impl->h_array_borrowed = NULL;
192       impl->h_array          = impl->h_array_owned;
193       if (array) {
194         CeedSize length;
195         size_t   bytes;
196 
197         CeedCallBackend(CeedVectorGetLength(vec, &length));
198         bytes = length * sizeof(CeedScalar);
199         memcpy(impl->h_array, array, bytes);
200       }
201     } break;
202     case CEED_OWN_POINTER:
203       CeedCallBackend(CeedFree(&impl->h_array_owned));
204       impl->h_array_owned    = array;
205       impl->h_array_borrowed = NULL;
206       impl->h_array          = array;
207       break;
208     case CEED_USE_POINTER:
209       CeedCallBackend(CeedFree(&impl->h_array_owned));
210       impl->h_array_borrowed = array;
211       impl->h_array          = array;
212       break;
213   }
214   return CEED_ERROR_SUCCESS;
215 }
216 
217 //------------------------------------------------------------------------------
218 // Set array from device
219 //------------------------------------------------------------------------------
220 static int CeedVectorSetArrayDevice_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
221   Ceed            ceed;
222   CeedVector_Hip *impl;
223 
224   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
225   CeedCallBackend(CeedVectorGetData(vec, &impl));
226   switch (copy_mode) {
227     case CEED_COPY_VALUES: {
228       CeedSize length;
229       size_t   bytes;
230 
231       CeedCallBackend(CeedVectorGetLength(vec, &length));
232       bytes = length * sizeof(CeedScalar);
233       if (!impl->d_array_owned) {
234         CeedCallHip(ceed, hipMalloc((void **)&impl->d_array_owned, bytes));
235       }
236       impl->d_array_borrowed = NULL;
237       impl->d_array          = impl->d_array_owned;
238       if (array) CeedCallHip(ceed, hipMemcpy(impl->d_array, array, bytes, hipMemcpyDeviceToDevice));
239     } break;
240     case CEED_OWN_POINTER:
241       CeedCallHip(ceed, hipFree(impl->d_array_owned));
242       impl->d_array_owned    = array;
243       impl->d_array_borrowed = NULL;
244       impl->d_array          = array;
245       break;
246     case CEED_USE_POINTER:
247       CeedCallHip(ceed, hipFree(impl->d_array_owned));
248       impl->d_array_owned    = NULL;
249       impl->d_array_borrowed = array;
250       impl->d_array          = array;
251       break;
252   }
253   return CEED_ERROR_SUCCESS;
254 }
255 
256 //------------------------------------------------------------------------------
257 // Set the array used by a vector,
258 //   freeing any previously allocated array if applicable
259 //------------------------------------------------------------------------------
260 static int CeedVectorSetArray_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) {
261   Ceed            ceed;
262   CeedVector_Hip *impl;
263 
264   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
265   CeedCallBackend(CeedVectorGetData(vec, &impl));
266   CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
267   switch (mem_type) {
268     case CEED_MEM_HOST:
269       return CeedVectorSetArrayHost_Hip(vec, copy_mode, array);
270     case CEED_MEM_DEVICE:
271       return CeedVectorSetArrayDevice_Hip(vec, copy_mode, array);
272   }
273   return CEED_ERROR_UNSUPPORTED;
274 }
275 
276 //------------------------------------------------------------------------------
277 // Set host array to value
278 //------------------------------------------------------------------------------
279 static int CeedHostSetValue_Hip(CeedScalar *h_array, CeedSize length, CeedScalar val) {
280   for (CeedSize i = 0; i < length; i++) h_array[i] = val;
281   return CEED_ERROR_SUCCESS;
282 }
283 
284 //------------------------------------------------------------------------------
285 // Set device array to value (impl in .hip file)
286 //------------------------------------------------------------------------------
287 int CeedDeviceSetValue_Hip(CeedScalar *d_array, CeedSize length, CeedScalar val);
288 
289 //------------------------------------------------------------------------------
290 // Set a vector to a value
291 //------------------------------------------------------------------------------
292 static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) {
293   Ceed            ceed;
294   CeedSize        length;
295   CeedVector_Hip *impl;
296 
297   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
298   CeedCallBackend(CeedVectorGetData(vec, &impl));
299   CeedCallBackend(CeedVectorGetLength(vec, &length));
300   // Set value for synced device/host array
301   if (!impl->d_array && !impl->h_array) {
302     if (impl->d_array_borrowed) {
303       impl->d_array = impl->d_array_borrowed;
304     } else if (impl->h_array_borrowed) {
305       impl->h_array = impl->h_array_borrowed;
306     } else if (impl->d_array_owned) {
307       impl->d_array = impl->d_array_owned;
308     } else if (impl->h_array_owned) {
309       impl->h_array = impl->h_array_owned;
310     } else {
311       CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL));
312     }
313   }
314   if (impl->d_array) {
315     CeedCallBackend(CeedDeviceSetValue_Hip(impl->d_array, length, val));
316     impl->h_array = NULL;
317   }
318   if (impl->h_array) {
319     CeedCallBackend(CeedHostSetValue_Hip(impl->h_array, length, val));
320     impl->d_array = NULL;
321   }
322   return CEED_ERROR_SUCCESS;
323 }
324 
325 //------------------------------------------------------------------------------
326 // Vector Take Array
327 //------------------------------------------------------------------------------
328 static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
329   Ceed            ceed;
330   CeedVector_Hip *impl;
331 
332   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
333   CeedCallBackend(CeedVectorGetData(vec, &impl));
334 
335   // Sync array to requested mem_type
336   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
337 
338   // Update pointer
339   switch (mem_type) {
340     case CEED_MEM_HOST:
341       (*array)               = impl->h_array_borrowed;
342       impl->h_array_borrowed = NULL;
343       impl->h_array          = NULL;
344       break;
345     case CEED_MEM_DEVICE:
346       (*array)               = impl->d_array_borrowed;
347       impl->d_array_borrowed = NULL;
348       impl->d_array          = NULL;
349       break;
350   }
351   return CEED_ERROR_SUCCESS;
352 }
353 
354 //------------------------------------------------------------------------------
355 // Core logic for array syncronization for GetArray.
356 //   If a different memory type is most up to date, this will perform a copy
357 //------------------------------------------------------------------------------
358 static int CeedVectorGetArrayCore_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
359   Ceed            ceed;
360   CeedVector_Hip *impl;
361 
362   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
363   CeedCallBackend(CeedVectorGetData(vec, &impl));
364 
365   // Sync array to requested mem_type
366   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
367 
368   // Update pointer
369   switch (mem_type) {
370     case CEED_MEM_HOST:
371       *array = impl->h_array;
372       break;
373     case CEED_MEM_DEVICE:
374       *array = impl->d_array;
375       break;
376   }
377   return CEED_ERROR_SUCCESS;
378 }
379 
380 //------------------------------------------------------------------------------
381 // Get read-only access to a vector via the specified mem_type
382 //------------------------------------------------------------------------------
383 static int CeedVectorGetArrayRead_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) {
384   return CeedVectorGetArrayCore_Hip(vec, mem_type, (CeedScalar **)array);
385 }
386 
387 //------------------------------------------------------------------------------
388 // Get read/write access to a vector via the specified mem_type
389 //------------------------------------------------------------------------------
390 static int CeedVectorGetArray_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
391   CeedVector_Hip *impl;
392 
393   CeedCallBackend(CeedVectorGetData(vec, &impl));
394   CeedCallBackend(CeedVectorGetArrayCore_Hip(vec, mem_type, array));
395   CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
396   switch (mem_type) {
397     case CEED_MEM_HOST:
398       impl->h_array = *array;
399       break;
400     case CEED_MEM_DEVICE:
401       impl->d_array = *array;
402       break;
403   }
404   return CEED_ERROR_SUCCESS;
405 }
406 
407 //------------------------------------------------------------------------------
408 // Get write access to a vector via the specified mem_type
409 //------------------------------------------------------------------------------
410 static int CeedVectorGetArrayWrite_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
411   bool            has_array_of_type = true;
412   CeedVector_Hip *impl;
413 
414   CeedCallBackend(CeedVectorGetData(vec, &impl));
415   CeedCallBackend(CeedVectorHasArrayOfType_Hip(vec, mem_type, &has_array_of_type));
416   if (!has_array_of_type) {
417     // Allocate if array is not yet allocated
418     CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL));
419   } else {
420     // Select dirty array
421     switch (mem_type) {
422       case CEED_MEM_HOST:
423         if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed;
424         else impl->h_array = impl->h_array_owned;
425         break;
426       case CEED_MEM_DEVICE:
427         if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed;
428         else impl->d_array = impl->d_array_owned;
429     }
430   }
431   return CeedVectorGetArray_Hip(vec, mem_type, array);
432 }
433 
434 //------------------------------------------------------------------------------
435 // Get the norm of a CeedVector
436 //------------------------------------------------------------------------------
437 static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *norm) {
438   Ceed              ceed;
439   CeedSize          length, num_calls;
440   const CeedScalar *d_array;
441   CeedVector_Hip   *impl;
442   hipblasHandle_t   handle;
443 
444   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
445   CeedCallBackend(CeedVectorGetData(vec, &impl));
446   CeedCallBackend(CeedVectorGetLength(vec, &length));
447   CeedCallBackend(CeedGetHipblasHandle_Hip(ceed, &handle));
448 
449   // Is the vector too long to handle with int32? If so, we will divide
450   // it up into "int32-sized" subsections and make repeated BLAS calls.
451   num_calls = length / INT_MAX;
452   if (length % INT_MAX > 0) num_calls += 1;
453 
454   // Compute norm
455   CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array));
456   switch (type) {
457     case CEED_NORM_1: {
458       *norm = 0.0;
459       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
460         float  sub_norm = 0.0;
461         float *d_array_start;
462 
463         for (CeedInt i = 0; i < num_calls; i++) {
464           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
465           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
466           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
467 
468           CeedCallHipblas(ceed, hipblasSasum(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
469           *norm += sub_norm;
470         }
471       } else {
472         double  sub_norm = 0.0;
473         double *d_array_start;
474 
475         for (CeedInt i = 0; i < num_calls; i++) {
476           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
477           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
478           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
479 
480           CeedCallHipblas(ceed, hipblasDasum(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
481           *norm += sub_norm;
482         }
483       }
484       break;
485     }
486     case CEED_NORM_2: {
487       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
488         float  sub_norm = 0.0, norm_sum = 0.0;
489         float *d_array_start;
490 
491         for (CeedInt i = 0; i < num_calls; i++) {
492           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
493           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
494           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
495 
496           CeedCallHipblas(ceed, hipblasSnrm2(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
497           norm_sum += sub_norm * sub_norm;
498         }
499         *norm = sqrt(norm_sum);
500       } else {
501         double  sub_norm = 0.0, norm_sum = 0.0;
502         double *d_array_start;
503 
504         for (CeedInt i = 0; i < num_calls; i++) {
505           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
506           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
507           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
508 
509           CeedCallHipblas(ceed, hipblasDnrm2(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
510           norm_sum += sub_norm * sub_norm;
511         }
512         *norm = sqrt(norm_sum);
513       }
514       break;
515     }
516     case CEED_NORM_MAX: {
517       CeedInt index;
518 
519       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
520         float  sub_max = 0.0, current_max = 0.0;
521         float *d_array_start;
522         for (CeedInt i = 0; i < num_calls; i++) {
523           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
524           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
525           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
526 
527           CeedCallHipblas(ceed, hipblasIsamax(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &index));
528           CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
529           if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
530         }
531         *norm = current_max;
532       } else {
533         double  sub_max = 0.0, current_max = 0.0;
534         double *d_array_start;
535 
536         for (CeedInt i = 0; i < num_calls; i++) {
537           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
538           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
539           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
540 
541           CeedCallHipblas(ceed, hipblasIdamax(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &index));
542           CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
543           if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
544         }
545         *norm = current_max;
546       }
547       break;
548     }
549   }
550   CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array));
551   return CEED_ERROR_SUCCESS;
552 }
553 
554 //------------------------------------------------------------------------------
555 // Take reciprocal of a vector on host
556 //------------------------------------------------------------------------------
557 static int CeedHostReciprocal_Hip(CeedScalar *h_array, CeedSize length) {
558   for (CeedSize i = 0; i < length; i++) {
559     if (fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i];
560   }
561   return CEED_ERROR_SUCCESS;
562 }
563 
564 //------------------------------------------------------------------------------
565 // Take reciprocal of a vector on device (impl in .cu file)
566 //------------------------------------------------------------------------------
567 int CeedDeviceReciprocal_Hip(CeedScalar *d_array, CeedSize length);
568 
569 //------------------------------------------------------------------------------
570 // Take reciprocal of a vector
571 //------------------------------------------------------------------------------
572 static int CeedVectorReciprocal_Hip(CeedVector vec) {
573   Ceed            ceed;
574   CeedSize        length;
575   CeedVector_Hip *impl;
576 
577   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
578   CeedCallBackend(CeedVectorGetData(vec, &impl));
579   CeedCallBackend(CeedVectorGetLength(vec, &length));
580   // Set value for synced device/host array
581   if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Hip(impl->d_array, length));
582   if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Hip(impl->h_array, length));
583   return CEED_ERROR_SUCCESS;
584 }
585 
586 //------------------------------------------------------------------------------
587 // Compute x = alpha x on the host
588 //------------------------------------------------------------------------------
589 static int CeedHostScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length) {
590   for (CeedSize i = 0; i < length; i++) x_array[i] *= alpha;
591   return CEED_ERROR_SUCCESS;
592 }
593 
594 //------------------------------------------------------------------------------
595 // Compute x = alpha x on device (impl in .cu file)
596 //------------------------------------------------------------------------------
597 int CeedDeviceScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length);
598 
599 //------------------------------------------------------------------------------
600 // Compute x = alpha x
601 //------------------------------------------------------------------------------
602 static int CeedVectorScale_Hip(CeedVector x, CeedScalar alpha) {
603   Ceed            ceed;
604   CeedSize        length;
605   CeedVector_Hip *x_impl;
606 
607   CeedCallBackend(CeedVectorGetCeed(x, &ceed));
608   CeedCallBackend(CeedVectorGetData(x, &x_impl));
609   CeedCallBackend(CeedVectorGetLength(x, &length));
610   // Set value for synced device/host array
611   if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Hip(x_impl->d_array, alpha, length));
612   if (x_impl->h_array) CeedCallBackend(CeedHostScale_Hip(x_impl->h_array, alpha, length));
613   return CEED_ERROR_SUCCESS;
614 }
615 
616 //------------------------------------------------------------------------------
617 // Compute y = alpha x + y on the host
618 //------------------------------------------------------------------------------
619 static int CeedHostAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) {
620   for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i];
621   return CEED_ERROR_SUCCESS;
622 }
623 
624 //------------------------------------------------------------------------------
625 // Compute y = alpha x + y on device (impl in .cu file)
626 //------------------------------------------------------------------------------
627 int CeedDeviceAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length);
628 
629 //------------------------------------------------------------------------------
630 // Compute y = alpha x + y
631 //------------------------------------------------------------------------------
632 static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) {
633   Ceed            ceed;
634   CeedSize        length;
635   CeedVector_Hip *y_impl, *x_impl;
636 
637   CeedCallBackend(CeedVectorGetCeed(y, &ceed));
638   CeedCallBackend(CeedVectorGetData(y, &y_impl));
639   CeedCallBackend(CeedVectorGetData(x, &x_impl));
640   CeedCallBackend(CeedVectorGetLength(y, &length));
641   // Set value for synced device/host array
642   if (y_impl->d_array) {
643     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
644     CeedCallBackend(CeedDeviceAXPY_Hip(y_impl->d_array, alpha, x_impl->d_array, length));
645   }
646   if (y_impl->h_array) {
647     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
648     CeedCallBackend(CeedHostAXPY_Hip(y_impl->h_array, alpha, x_impl->h_array, length));
649   }
650   return CEED_ERROR_SUCCESS;
651 }
652 
653 //------------------------------------------------------------------------------
654 // Compute y = alpha x + beta y on the host
655 //------------------------------------------------------------------------------
656 static int CeedHostAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length) {
657   for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i] + beta * y_array[i];
658   return CEED_ERROR_SUCCESS;
659 }
660 
661 //------------------------------------------------------------------------------
662 // Compute y = alpha x + beta y on device (impl in .cu file)
663 //------------------------------------------------------------------------------
664 int CeedDeviceAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length);
665 
666 //------------------------------------------------------------------------------
667 // Compute y = alpha x + beta y
668 //------------------------------------------------------------------------------
669 static int CeedVectorAXPBY_Hip(CeedVector y, CeedScalar alpha, CeedScalar beta, CeedVector x) {
670   Ceed            ceed;
671   CeedSize        length;
672   CeedVector_Hip *y_impl, *x_impl;
673 
674   CeedCallBackend(CeedVectorGetCeed(y, &ceed));
675   CeedCallBackend(CeedVectorGetData(y, &y_impl));
676   CeedCallBackend(CeedVectorGetData(x, &x_impl));
677   CeedCallBackend(CeedVectorGetLength(y, &length));
678   // Set value for synced device/host array
679   if (y_impl->d_array) {
680     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
681     CeedCallBackend(CeedDeviceAXPBY_Hip(y_impl->d_array, alpha, beta, x_impl->d_array, length));
682   }
683   if (y_impl->h_array) {
684     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
685     CeedCallBackend(CeedHostAXPBY_Hip(y_impl->h_array, alpha, beta, x_impl->h_array, length));
686   }
687   return CEED_ERROR_SUCCESS;
688 }
689 
690 //------------------------------------------------------------------------------
691 // Compute the pointwise multiplication w = x .* y on the host
692 //------------------------------------------------------------------------------
693 static int CeedHostPointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) {
694   for (CeedSize i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i];
695   return CEED_ERROR_SUCCESS;
696 }
697 
698 //------------------------------------------------------------------------------
699 // Compute the pointwise multiplication w = x .* y on device (impl in .cu file)
700 //------------------------------------------------------------------------------
701 int CeedDevicePointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length);
702 
703 //------------------------------------------------------------------------------
704 // Compute the pointwise multiplication w = x .* y
705 //------------------------------------------------------------------------------
706 static int CeedVectorPointwiseMult_Hip(CeedVector w, CeedVector x, CeedVector y) {
707   Ceed            ceed;
708   CeedSize        length;
709   CeedVector_Hip *w_impl, *x_impl, *y_impl;
710 
711   CeedCallBackend(CeedVectorGetCeed(w, &ceed));
712   CeedCallBackend(CeedVectorGetData(w, &w_impl));
713   CeedCallBackend(CeedVectorGetData(x, &x_impl));
714   CeedCallBackend(CeedVectorGetData(y, &y_impl));
715   CeedCallBackend(CeedVectorGetLength(w, &length));
716 
717   // Set value for synced device/host array
718   if (!w_impl->d_array && !w_impl->h_array) {
719     CeedCallBackend(CeedVectorSetValue(w, 0.0));
720   }
721   if (w_impl->d_array) {
722     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
723     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE));
724     CeedCallBackend(CeedDevicePointwiseMult_Hip(w_impl->d_array, x_impl->d_array, y_impl->d_array, length));
725   }
726   if (w_impl->h_array) {
727     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
728     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST));
729     CeedCallBackend(CeedHostPointwiseMult_Hip(w_impl->h_array, x_impl->h_array, y_impl->h_array, length));
730   }
731   return CEED_ERROR_SUCCESS;
732 }
733 
734 //------------------------------------------------------------------------------
735 // Destroy the vector
736 //------------------------------------------------------------------------------
737 static int CeedVectorDestroy_Hip(const CeedVector vec) {
738   Ceed            ceed;
739   CeedVector_Hip *impl;
740 
741   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
742   CeedCallBackend(CeedVectorGetData(vec, &impl));
743   CeedCallHip(ceed, hipFree(impl->d_array_owned));
744   CeedCallBackend(CeedFree(&impl->h_array_owned));
745   CeedCallBackend(CeedFree(&impl));
746   return CEED_ERROR_SUCCESS;
747 }
748 
749 //------------------------------------------------------------------------------
750 // Create a vector of the specified length (does not allocate memory)
751 //------------------------------------------------------------------------------
752 int CeedVectorCreate_Hip(CeedSize n, CeedVector vec) {
753   CeedVector_Hip *impl;
754   Ceed            ceed;
755 
756   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
757   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Hip));
758   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Hip));
759   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Hip));
760   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Hip));
761   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", (int (*)())CeedVectorSetValue_Hip));
762   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Hip));
763   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Hip));
764   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Hip));
765   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Hip));
766   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Norm", CeedVectorNorm_Hip));
767   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Hip));
768   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Scale", (int (*)())CeedVectorScale_Hip));
769   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPY", (int (*)())CeedVectorAXPY_Hip));
770   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPBY", (int (*)())CeedVectorAXPBY_Hip));
771   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Hip));
772   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Hip));
773   CeedCallBackend(CeedCalloc(1, &impl));
774   CeedCallBackend(CeedVectorSetData(vec, impl));
775   return CEED_ERROR_SUCCESS;
776 }
777 
778 //------------------------------------------------------------------------------
779