xref: /libCEED/backends/hip-ref/ceed-hip-ref-vector.c (revision edb2538e3dd6743c029967fc4e89c6fcafedb8c2)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <math.h>
11 #include <stdbool.h>
12 #include <string.h>
13 #include <hip/hip_runtime.h>
14 
15 #include "../hip/ceed-hip-common.h"
16 #include "ceed-hip-ref.h"
17 
18 //------------------------------------------------------------------------------
19 // Check if host/device sync is needed
20 //------------------------------------------------------------------------------
21 static inline int CeedVectorNeedSync_Hip(const CeedVector vec, CeedMemType mem_type, bool *need_sync) {
22   CeedVector_Hip *impl;
23   bool            has_valid_array = false;
24 
25   CeedCallBackend(CeedVectorGetData(vec, &impl));
26   CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array));
27   switch (mem_type) {
28     case CEED_MEM_HOST:
29       *need_sync = has_valid_array && !impl->h_array;
30       break;
31     case CEED_MEM_DEVICE:
32       *need_sync = has_valid_array && !impl->d_array;
33       break;
34   }
35   return CEED_ERROR_SUCCESS;
36 }
37 
38 //------------------------------------------------------------------------------
39 // Sync host to device
40 //------------------------------------------------------------------------------
41 static inline int CeedVectorSyncH2D_Hip(const CeedVector vec) {
42   Ceed            ceed;
43   CeedSize        length;
44   CeedVector_Hip *impl;
45 
46   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
47   CeedCallBackend(CeedVectorGetData(vec, &impl));
48 
49   CeedCallBackend(CeedVectorGetLength(vec, &length));
50   size_t bytes = length * sizeof(CeedScalar);
51 
52   CeedCheck(impl->h_array, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
53 
54   if (impl->d_array_borrowed) {
55     impl->d_array = impl->d_array_borrowed;
56   } else if (impl->d_array_owned) {
57     impl->d_array = impl->d_array_owned;
58   } else {
59     CeedCallHip(ceed, hipMalloc((void **)&impl->d_array_owned, bytes));
60     impl->d_array = impl->d_array_owned;
61   }
62   CeedCallHip(ceed, hipMemcpy(impl->d_array, impl->h_array, bytes, hipMemcpyHostToDevice));
63   return CEED_ERROR_SUCCESS;
64 }
65 
66 //------------------------------------------------------------------------------
67 // Sync device to host
68 //------------------------------------------------------------------------------
69 static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) {
70   Ceed            ceed;
71   CeedSize        length;
72   CeedVector_Hip *impl;
73 
74   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
75   CeedCallBackend(CeedVectorGetData(vec, &impl));
76 
77   CeedCheck(impl->d_array, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
78 
79   if (impl->h_array_borrowed) {
80     impl->h_array = impl->h_array_borrowed;
81   } else if (impl->h_array_owned) {
82     impl->h_array = impl->h_array_owned;
83   } else {
84     CeedSize length;
85     CeedCallBackend(CeedVectorGetLength(vec, &length));
86     CeedCallBackend(CeedCalloc(length, &impl->h_array_owned));
87     impl->h_array = impl->h_array_owned;
88   }
89 
90   CeedCallBackend(CeedVectorGetLength(vec, &length));
91   size_t bytes = length * sizeof(CeedScalar);
92 
93   CeedCallHip(ceed, hipMemcpy(impl->h_array, impl->d_array, bytes, hipMemcpyDeviceToHost));
94   return CEED_ERROR_SUCCESS;
95 }
96 
97 //------------------------------------------------------------------------------
98 // Sync arrays
99 //------------------------------------------------------------------------------
100 static int CeedVectorSyncArray_Hip(const CeedVector vec, CeedMemType mem_type) {
101   bool need_sync = false;
102 
103   // Check whether device/host sync is needed
104   CeedCallBackend(CeedVectorNeedSync_Hip(vec, mem_type, &need_sync));
105   if (!need_sync) return CEED_ERROR_SUCCESS;
106 
107   switch (mem_type) {
108     case CEED_MEM_HOST:
109       return CeedVectorSyncD2H_Hip(vec);
110     case CEED_MEM_DEVICE:
111       return CeedVectorSyncH2D_Hip(vec);
112   }
113   return CEED_ERROR_UNSUPPORTED;
114 }
115 
116 //------------------------------------------------------------------------------
117 // Set all pointers as invalid
118 //------------------------------------------------------------------------------
119 static inline int CeedVectorSetAllInvalid_Hip(const CeedVector vec) {
120   CeedVector_Hip *impl;
121 
122   CeedCallBackend(CeedVectorGetData(vec, &impl));
123   impl->h_array = NULL;
124   impl->d_array = NULL;
125   return CEED_ERROR_SUCCESS;
126 }
127 
128 //------------------------------------------------------------------------------
129 // Check if CeedVector has any valid pointer
130 //------------------------------------------------------------------------------
131 static inline int CeedVectorHasValidArray_Hip(const CeedVector vec, bool *has_valid_array) {
132   CeedVector_Hip *impl;
133 
134   CeedCallBackend(CeedVectorGetData(vec, &impl));
135   *has_valid_array = impl->h_array || impl->d_array;
136   return CEED_ERROR_SUCCESS;
137 }
138 
139 //------------------------------------------------------------------------------
140 // Check if has array of given type
141 //------------------------------------------------------------------------------
142 static inline int CeedVectorHasArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) {
143   CeedVector_Hip *impl;
144 
145   CeedCallBackend(CeedVectorGetData(vec, &impl));
146   switch (mem_type) {
147     case CEED_MEM_HOST:
148       *has_array_of_type = impl->h_array_borrowed || impl->h_array_owned;
149       break;
150     case CEED_MEM_DEVICE:
151       *has_array_of_type = impl->d_array_borrowed || impl->d_array_owned;
152       break;
153   }
154   return CEED_ERROR_SUCCESS;
155 }
156 
157 //------------------------------------------------------------------------------
158 // Check if has borrowed array of given type
159 //------------------------------------------------------------------------------
160 static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) {
161   CeedVector_Hip *impl;
162 
163   CeedCallBackend(CeedVectorGetData(vec, &impl));
164   switch (mem_type) {
165     case CEED_MEM_HOST:
166       *has_borrowed_array_of_type = impl->h_array_borrowed;
167       break;
168     case CEED_MEM_DEVICE:
169       *has_borrowed_array_of_type = impl->d_array_borrowed;
170       break;
171   }
172   return CEED_ERROR_SUCCESS;
173 }
174 
175 //------------------------------------------------------------------------------
176 // Set array from host
177 //------------------------------------------------------------------------------
178 static int CeedVectorSetArrayHost_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
179   CeedVector_Hip *impl;
180 
181   CeedCallBackend(CeedVectorGetData(vec, &impl));
182   switch (copy_mode) {
183     case CEED_COPY_VALUES: {
184       CeedSize length;
185 
186       if (!impl->h_array_owned) {
187         CeedCallBackend(CeedVectorGetLength(vec, &length));
188         CeedCallBackend(CeedMalloc(length, &impl->h_array_owned));
189       }
190       impl->h_array_borrowed = NULL;
191       impl->h_array          = impl->h_array_owned;
192       if (array) {
193         CeedSize length;
194 
195         CeedCallBackend(CeedVectorGetLength(vec, &length));
196         size_t bytes = length * sizeof(CeedScalar);
197         memcpy(impl->h_array, array, bytes);
198       }
199     } break;
200     case CEED_OWN_POINTER:
201       CeedCallBackend(CeedFree(&impl->h_array_owned));
202       impl->h_array_owned    = array;
203       impl->h_array_borrowed = NULL;
204       impl->h_array          = array;
205       break;
206     case CEED_USE_POINTER:
207       CeedCallBackend(CeedFree(&impl->h_array_owned));
208       impl->h_array_borrowed = array;
209       impl->h_array          = array;
210       break;
211   }
212   return CEED_ERROR_SUCCESS;
213 }
214 
215 //------------------------------------------------------------------------------
216 // Set array from device
217 //------------------------------------------------------------------------------
218 static int CeedVectorSetArrayDevice_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
219   Ceed            ceed;
220   CeedVector_Hip *impl;
221 
222   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
223   CeedCallBackend(CeedVectorGetData(vec, &impl));
224   switch (copy_mode) {
225     case CEED_COPY_VALUES: {
226       CeedSize length;
227 
228       CeedCallBackend(CeedVectorGetLength(vec, &length));
229       size_t bytes = length * sizeof(CeedScalar);
230 
231       if (!impl->d_array_owned) {
232         CeedCallHip(ceed, hipMalloc((void **)&impl->d_array_owned, bytes));
233       }
234       impl->d_array_borrowed = NULL;
235       impl->d_array          = impl->d_array_owned;
236       if (array) CeedCallHip(ceed, hipMemcpy(impl->d_array, array, bytes, hipMemcpyDeviceToDevice));
237     } break;
238     case CEED_OWN_POINTER:
239       CeedCallHip(ceed, hipFree(impl->d_array_owned));
240       impl->d_array_owned    = array;
241       impl->d_array_borrowed = NULL;
242       impl->d_array          = array;
243       break;
244     case CEED_USE_POINTER:
245       CeedCallHip(ceed, hipFree(impl->d_array_owned));
246       impl->d_array_owned    = NULL;
247       impl->d_array_borrowed = array;
248       impl->d_array          = array;
249       break;
250   }
251   return CEED_ERROR_SUCCESS;
252 }
253 
254 //------------------------------------------------------------------------------
255 // Set the array used by a vector,
256 //   freeing any previously allocated array if applicable
257 //------------------------------------------------------------------------------
258 static int CeedVectorSetArray_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) {
259   Ceed            ceed;
260   CeedVector_Hip *impl;
261 
262   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
263   CeedCallBackend(CeedVectorGetData(vec, &impl));
264   CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
265   switch (mem_type) {
266     case CEED_MEM_HOST:
267       return CeedVectorSetArrayHost_Hip(vec, copy_mode, array);
268     case CEED_MEM_DEVICE:
269       return CeedVectorSetArrayDevice_Hip(vec, copy_mode, array);
270   }
271   return CEED_ERROR_UNSUPPORTED;
272 }
273 
274 //------------------------------------------------------------------------------
275 // Set host array to value
276 //------------------------------------------------------------------------------
277 static int CeedHostSetValue_Hip(CeedScalar *h_array, CeedSize length, CeedScalar val) {
278   for (CeedSize i = 0; i < length; i++) h_array[i] = val;
279   return CEED_ERROR_SUCCESS;
280 }
281 
282 //------------------------------------------------------------------------------
283 // Set device array to value (impl in .hip file)
284 //------------------------------------------------------------------------------
285 int CeedDeviceSetValue_Hip(CeedScalar *d_array, CeedSize length, CeedScalar val);
286 
287 //------------------------------------------------------------------------------
288 // Set a vector to a value
289 //------------------------------------------------------------------------------
290 static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) {
291   Ceed            ceed;
292   CeedSize        length;
293   CeedVector_Hip *impl;
294 
295   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
296   CeedCallBackend(CeedVectorGetData(vec, &impl));
297   CeedCallBackend(CeedVectorGetLength(vec, &length));
298   // Set value for synced device/host array
299   if (!impl->d_array && !impl->h_array) {
300     if (impl->d_array_borrowed) {
301       impl->d_array = impl->d_array_borrowed;
302     } else if (impl->h_array_borrowed) {
303       impl->h_array = impl->h_array_borrowed;
304     } else if (impl->d_array_owned) {
305       impl->d_array = impl->d_array_owned;
306     } else if (impl->h_array_owned) {
307       impl->h_array = impl->h_array_owned;
308     } else {
309       CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL));
310     }
311   }
312   if (impl->d_array) {
313     CeedCallBackend(CeedDeviceSetValue_Hip(impl->d_array, length, val));
314     impl->h_array = NULL;
315   }
316   if (impl->h_array) {
317     CeedCallBackend(CeedHostSetValue_Hip(impl->h_array, length, val));
318     impl->d_array = NULL;
319   }
320   return CEED_ERROR_SUCCESS;
321 }
322 
323 //------------------------------------------------------------------------------
324 // Vector Take Array
325 //------------------------------------------------------------------------------
326 static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
327   Ceed            ceed;
328   CeedVector_Hip *impl;
329 
330   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
331   CeedCallBackend(CeedVectorGetData(vec, &impl));
332 
333   // Sync array to requested mem_type
334   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
335 
336   // Update pointer
337   switch (mem_type) {
338     case CEED_MEM_HOST:
339       (*array)               = impl->h_array_borrowed;
340       impl->h_array_borrowed = NULL;
341       impl->h_array          = NULL;
342       break;
343     case CEED_MEM_DEVICE:
344       (*array)               = impl->d_array_borrowed;
345       impl->d_array_borrowed = NULL;
346       impl->d_array          = NULL;
347       break;
348   }
349   return CEED_ERROR_SUCCESS;
350 }
351 
352 //------------------------------------------------------------------------------
353 // Core logic for array syncronization for GetArray.
354 //   If a different memory type is most up to date, this will perform a copy
355 //------------------------------------------------------------------------------
356 static int CeedVectorGetArrayCore_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
357   Ceed            ceed;
358   CeedVector_Hip *impl;
359 
360   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
361   CeedCallBackend(CeedVectorGetData(vec, &impl));
362 
363   // Sync array to requested mem_type
364   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
365 
366   // Update pointer
367   switch (mem_type) {
368     case CEED_MEM_HOST:
369       *array = impl->h_array;
370       break;
371     case CEED_MEM_DEVICE:
372       *array = impl->d_array;
373       break;
374   }
375   return CEED_ERROR_SUCCESS;
376 }
377 
378 //------------------------------------------------------------------------------
379 // Get read-only access to a vector via the specified mem_type
380 //------------------------------------------------------------------------------
381 static int CeedVectorGetArrayRead_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) {
382   return CeedVectorGetArrayCore_Hip(vec, mem_type, (CeedScalar **)array);
383 }
384 
385 //------------------------------------------------------------------------------
386 // Get read/write access to a vector via the specified mem_type
387 //------------------------------------------------------------------------------
388 static int CeedVectorGetArray_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
389   CeedVector_Hip *impl;
390 
391   CeedCallBackend(CeedVectorGetData(vec, &impl));
392   CeedCallBackend(CeedVectorGetArrayCore_Hip(vec, mem_type, array));
393   CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
394   switch (mem_type) {
395     case CEED_MEM_HOST:
396       impl->h_array = *array;
397       break;
398     case CEED_MEM_DEVICE:
399       impl->d_array = *array;
400       break;
401   }
402   return CEED_ERROR_SUCCESS;
403 }
404 
405 //------------------------------------------------------------------------------
406 // Get write access to a vector via the specified mem_type
407 //------------------------------------------------------------------------------
408 static int CeedVectorGetArrayWrite_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
409   bool            has_array_of_type = true;
410   CeedVector_Hip *impl;
411 
412   CeedCallBackend(CeedVectorGetData(vec, &impl));
413   CeedCallBackend(CeedVectorHasArrayOfType_Hip(vec, mem_type, &has_array_of_type));
414   if (!has_array_of_type) {
415     // Allocate if array is not yet allocated
416     CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL));
417   } else {
418     // Select dirty array
419     switch (mem_type) {
420       case CEED_MEM_HOST:
421         if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed;
422         else impl->h_array = impl->h_array_owned;
423         break;
424       case CEED_MEM_DEVICE:
425         if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed;
426         else impl->d_array = impl->d_array_owned;
427     }
428   }
429   return CeedVectorGetArray_Hip(vec, mem_type, array);
430 }
431 
432 //------------------------------------------------------------------------------
433 // Get the norm of a CeedVector
434 //------------------------------------------------------------------------------
435 static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *norm) {
436   Ceed              ceed;
437   CeedSize          length;
438   const CeedScalar *d_array;
439   CeedVector_Hip   *impl;
440   hipblasHandle_t   handle;
441 
442   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
443   CeedCallBackend(CeedVectorGetData(vec, &impl));
444   CeedCallBackend(CeedVectorGetLength(vec, &length));
445   CeedCallBackend(CeedGetHipblasHandle_Hip(ceed, &handle));
446 
447   // Is the vector too long to handle with int32? If so, we will divide
448   // it up into "int32-sized" subsections and make repeated BLAS calls.
449   CeedSize num_calls = length / INT_MAX;
450 
451   if (length % INT_MAX > 0) num_calls += 1;
452 
453   // Compute norm
454   CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array));
455   switch (type) {
456     case CEED_NORM_1: {
457       *norm = 0.0;
458       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
459         float  sub_norm = 0.0;
460         float *d_array_start;
461 
462         for (CeedInt i = 0; i < num_calls; i++) {
463           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
464           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
465           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
466 
467           CeedCallHipblas(ceed, hipblasSasum(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
468           *norm += sub_norm;
469         }
470       } else {
471         double  sub_norm = 0.0;
472         double *d_array_start;
473 
474         for (CeedInt i = 0; i < num_calls; i++) {
475           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
476           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
477           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
478 
479           CeedCallHipblas(ceed, hipblasDasum(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
480           *norm += sub_norm;
481         }
482       }
483       break;
484     }
485     case CEED_NORM_2: {
486       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
487         float  sub_norm = 0.0, norm_sum = 0.0;
488         float *d_array_start;
489 
490         for (CeedInt i = 0; i < num_calls; i++) {
491           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
492           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
493           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
494 
495           CeedCallHipblas(ceed, hipblasSnrm2(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
496           norm_sum += sub_norm * sub_norm;
497         }
498         *norm = sqrt(norm_sum);
499       } else {
500         double  sub_norm = 0.0, norm_sum = 0.0;
501         double *d_array_start;
502 
503         for (CeedInt i = 0; i < num_calls; i++) {
504           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
505           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
506           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
507 
508           CeedCallHipblas(ceed, hipblasDnrm2(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
509           norm_sum += sub_norm * sub_norm;
510         }
511         *norm = sqrt(norm_sum);
512       }
513       break;
514     }
515     case CEED_NORM_MAX: {
516       CeedInt index;
517 
518       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
519         float  sub_max = 0.0, current_max = 0.0;
520         float *d_array_start;
521         for (CeedInt i = 0; i < num_calls; i++) {
522           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
523           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
524           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
525 
526           CeedCallHipblas(ceed, hipblasIsamax(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &index));
527           CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
528           if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
529         }
530         *norm = current_max;
531       } else {
532         double  sub_max = 0.0, current_max = 0.0;
533         double *d_array_start;
534 
535         for (CeedInt i = 0; i < num_calls; i++) {
536           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
537           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
538           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
539 
540           CeedCallHipblas(ceed, hipblasIdamax(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &index));
541           CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
542           if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
543         }
544         *norm = current_max;
545       }
546       break;
547     }
548   }
549   CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array));
550   return CEED_ERROR_SUCCESS;
551 }
552 
553 //------------------------------------------------------------------------------
554 // Take reciprocal of a vector on host
555 //------------------------------------------------------------------------------
556 static int CeedHostReciprocal_Hip(CeedScalar *h_array, CeedSize length) {
557   for (CeedSize i = 0; i < length; i++) {
558     if (fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i];
559   }
560   return CEED_ERROR_SUCCESS;
561 }
562 
563 //------------------------------------------------------------------------------
564 // Take reciprocal of a vector on device (impl in .hip file)
565 //------------------------------------------------------------------------------
566 int CeedDeviceReciprocal_Hip(CeedScalar *d_array, CeedSize length);
567 
568 //------------------------------------------------------------------------------
569 // Take reciprocal of a vector
570 //------------------------------------------------------------------------------
571 static int CeedVectorReciprocal_Hip(CeedVector vec) {
572   Ceed            ceed;
573   CeedSize        length;
574   CeedVector_Hip *impl;
575 
576   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
577   CeedCallBackend(CeedVectorGetData(vec, &impl));
578   CeedCallBackend(CeedVectorGetLength(vec, &length));
579   // Set value for synced device/host array
580   if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Hip(impl->d_array, length));
581   if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Hip(impl->h_array, length));
582   return CEED_ERROR_SUCCESS;
583 }
584 
585 //------------------------------------------------------------------------------
586 // Compute x = alpha x on the host
587 //------------------------------------------------------------------------------
588 static int CeedHostScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length) {
589   for (CeedSize i = 0; i < length; i++) x_array[i] *= alpha;
590   return CEED_ERROR_SUCCESS;
591 }
592 
593 //------------------------------------------------------------------------------
594 // Compute x = alpha x on device (impl in .hip file)
595 //------------------------------------------------------------------------------
596 int CeedDeviceScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length);
597 
598 //------------------------------------------------------------------------------
599 // Compute x = alpha x
600 //------------------------------------------------------------------------------
601 static int CeedVectorScale_Hip(CeedVector x, CeedScalar alpha) {
602   Ceed            ceed;
603   CeedSize        length;
604   CeedVector_Hip *x_impl;
605 
606   CeedCallBackend(CeedVectorGetCeed(x, &ceed));
607   CeedCallBackend(CeedVectorGetData(x, &x_impl));
608   CeedCallBackend(CeedVectorGetLength(x, &length));
609   // Set value for synced device/host array
610   if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Hip(x_impl->d_array, alpha, length));
611   if (x_impl->h_array) CeedCallBackend(CeedHostScale_Hip(x_impl->h_array, alpha, length));
612   return CEED_ERROR_SUCCESS;
613 }
614 
615 //------------------------------------------------------------------------------
616 // Compute y = alpha x + y on the host
617 //------------------------------------------------------------------------------
618 static int CeedHostAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) {
619   for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i];
620   return CEED_ERROR_SUCCESS;
621 }
622 
623 //------------------------------------------------------------------------------
624 // Compute y = alpha x + y on device (impl in .hip file)
625 //------------------------------------------------------------------------------
626 int CeedDeviceAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length);
627 
628 //------------------------------------------------------------------------------
629 // Compute y = alpha x + y
630 //------------------------------------------------------------------------------
631 static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) {
632   Ceed            ceed;
633   CeedSize        length;
634   CeedVector_Hip *y_impl, *x_impl;
635 
636   CeedCallBackend(CeedVectorGetCeed(y, &ceed));
637   CeedCallBackend(CeedVectorGetData(y, &y_impl));
638   CeedCallBackend(CeedVectorGetData(x, &x_impl));
639   CeedCallBackend(CeedVectorGetLength(y, &length));
640   // Set value for synced device/host array
641   if (y_impl->d_array) {
642     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
643     CeedCallBackend(CeedDeviceAXPY_Hip(y_impl->d_array, alpha, x_impl->d_array, length));
644   }
645   if (y_impl->h_array) {
646     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
647     CeedCallBackend(CeedHostAXPY_Hip(y_impl->h_array, alpha, x_impl->h_array, length));
648   }
649   return CEED_ERROR_SUCCESS;
650 }
651 
652 //------------------------------------------------------------------------------
653 // Compute y = alpha x + beta y on the host
654 //------------------------------------------------------------------------------
655 static int CeedHostAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length) {
656   for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i] + beta * y_array[i];
657   return CEED_ERROR_SUCCESS;
658 }
659 
660 //------------------------------------------------------------------------------
661 // Compute y = alpha x + beta y on device (impl in .hip file)
662 //------------------------------------------------------------------------------
663 int CeedDeviceAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length);
664 
665 //------------------------------------------------------------------------------
666 // Compute y = alpha x + beta y
667 //------------------------------------------------------------------------------
668 static int CeedVectorAXPBY_Hip(CeedVector y, CeedScalar alpha, CeedScalar beta, CeedVector x) {
669   Ceed            ceed;
670   CeedSize        length;
671   CeedVector_Hip *y_impl, *x_impl;
672 
673   CeedCallBackend(CeedVectorGetCeed(y, &ceed));
674   CeedCallBackend(CeedVectorGetData(y, &y_impl));
675   CeedCallBackend(CeedVectorGetData(x, &x_impl));
676   CeedCallBackend(CeedVectorGetLength(y, &length));
677   // Set value for synced device/host array
678   if (y_impl->d_array) {
679     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
680     CeedCallBackend(CeedDeviceAXPBY_Hip(y_impl->d_array, alpha, beta, x_impl->d_array, length));
681   }
682   if (y_impl->h_array) {
683     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
684     CeedCallBackend(CeedHostAXPBY_Hip(y_impl->h_array, alpha, beta, x_impl->h_array, length));
685   }
686   return CEED_ERROR_SUCCESS;
687 }
688 
689 //------------------------------------------------------------------------------
690 // Compute the pointwise multiplication w = x .* y on the host
691 //------------------------------------------------------------------------------
692 static int CeedHostPointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) {
693   for (CeedSize i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i];
694   return CEED_ERROR_SUCCESS;
695 }
696 
697 //------------------------------------------------------------------------------
698 // Compute the pointwise multiplication w = x .* y on device (impl in .hip file)
699 //------------------------------------------------------------------------------
700 int CeedDevicePointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length);
701 
702 //------------------------------------------------------------------------------
703 // Compute the pointwise multiplication w = x .* y
704 //------------------------------------------------------------------------------
705 static int CeedVectorPointwiseMult_Hip(CeedVector w, CeedVector x, CeedVector y) {
706   Ceed            ceed;
707   CeedSize        length;
708   CeedVector_Hip *w_impl, *x_impl, *y_impl;
709 
710   CeedCallBackend(CeedVectorGetCeed(w, &ceed));
711   CeedCallBackend(CeedVectorGetData(w, &w_impl));
712   CeedCallBackend(CeedVectorGetData(x, &x_impl));
713   CeedCallBackend(CeedVectorGetData(y, &y_impl));
714   CeedCallBackend(CeedVectorGetLength(w, &length));
715 
716   // Set value for synced device/host array
717   if (!w_impl->d_array && !w_impl->h_array) {
718     CeedCallBackend(CeedVectorSetValue(w, 0.0));
719   }
720   if (w_impl->d_array) {
721     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
722     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE));
723     CeedCallBackend(CeedDevicePointwiseMult_Hip(w_impl->d_array, x_impl->d_array, y_impl->d_array, length));
724   }
725   if (w_impl->h_array) {
726     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
727     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST));
728     CeedCallBackend(CeedHostPointwiseMult_Hip(w_impl->h_array, x_impl->h_array, y_impl->h_array, length));
729   }
730   return CEED_ERROR_SUCCESS;
731 }
732 
733 //------------------------------------------------------------------------------
734 // Destroy the vector
735 //------------------------------------------------------------------------------
736 static int CeedVectorDestroy_Hip(const CeedVector vec) {
737   Ceed            ceed;
738   CeedVector_Hip *impl;
739 
740   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
741   CeedCallBackend(CeedVectorGetData(vec, &impl));
742   CeedCallHip(ceed, hipFree(impl->d_array_owned));
743   CeedCallBackend(CeedFree(&impl->h_array_owned));
744   CeedCallBackend(CeedFree(&impl));
745   return CEED_ERROR_SUCCESS;
746 }
747 
748 //------------------------------------------------------------------------------
749 // Create a vector of the specified length (does not allocate memory)
750 //------------------------------------------------------------------------------
751 int CeedVectorCreate_Hip(CeedSize n, CeedVector vec) {
752   CeedVector_Hip *impl;
753   Ceed            ceed;
754 
755   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
756   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Hip));
757   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Hip));
758   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Hip));
759   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Hip));
760   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", (int (*)())CeedVectorSetValue_Hip));
761   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Hip));
762   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Hip));
763   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Hip));
764   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Hip));
765   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Norm", CeedVectorNorm_Hip));
766   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Hip));
767   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Scale", (int (*)())CeedVectorScale_Hip));
768   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPY", (int (*)())CeedVectorAXPY_Hip));
769   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPBY", (int (*)())CeedVectorAXPBY_Hip));
770   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Hip));
771   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Hip));
772   CeedCallBackend(CeedCalloc(1, &impl));
773   CeedCallBackend(CeedVectorSetData(vec, impl));
774   return CEED_ERROR_SUCCESS;
775 }
776 
777 //------------------------------------------------------------------------------
778