xref: /libCEED/backends/hip-ref/ceed-hip-ref-vector.c (revision 3d13c0f212be81c681320aed470afb2f8a24bbef)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <math.h>
11 #include <stdbool.h>
12 #include <string.h>
13 #include <hip/hip_runtime.h>
14 
15 #include "../hip/ceed-hip-common.h"
16 #include "ceed-hip-ref.h"
17 
18 //------------------------------------------------------------------------------
19 // Check if host/device sync is needed
20 //------------------------------------------------------------------------------
21 static inline int CeedVectorNeedSync_Hip(const CeedVector vec, CeedMemType mem_type, bool *need_sync) {
22   CeedVector_Hip *impl;
23   CeedCallBackend(CeedVectorGetData(vec, &impl));
24 
25   bool has_valid_array = false;
26   CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array));
27   switch (mem_type) {
28     case CEED_MEM_HOST:
29       *need_sync = has_valid_array && !impl->h_array;
30       break;
31     case CEED_MEM_DEVICE:
32       *need_sync = has_valid_array && !impl->d_array;
33       break;
34   }
35 
36   return CEED_ERROR_SUCCESS;
37 }
38 
39 //------------------------------------------------------------------------------
40 // Sync host to device
41 //------------------------------------------------------------------------------
42 static inline int CeedVectorSyncH2D_Hip(const CeedVector vec) {
43   Ceed ceed;
44   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
45   CeedVector_Hip *impl;
46   CeedCallBackend(CeedVectorGetData(vec, &impl));
47 
48   CeedSize length;
49   CeedCallBackend(CeedVectorGetLength(vec, &length));
50   size_t bytes = length * sizeof(CeedScalar);
51 
52   CeedCheck(impl->h_array, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
53 
54   if (impl->d_array_borrowed) {
55     impl->d_array = impl->d_array_borrowed;
56   } else if (impl->d_array_owned) {
57     impl->d_array = impl->d_array_owned;
58   } else {
59     CeedCallHip(ceed, hipMalloc((void **)&impl->d_array_owned, bytes));
60     impl->d_array = impl->d_array_owned;
61   }
62 
63   CeedCallHip(ceed, hipMemcpy(impl->d_array, impl->h_array, bytes, hipMemcpyHostToDevice));
64 
65   return CEED_ERROR_SUCCESS;
66 }
67 
68 //------------------------------------------------------------------------------
69 // Sync device to host
70 //------------------------------------------------------------------------------
71 static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) {
72   Ceed ceed;
73   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
74   CeedVector_Hip *impl;
75   CeedCallBackend(CeedVectorGetData(vec, &impl));
76 
77   CeedCheck(impl->d_array, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
78 
79   if (impl->h_array_borrowed) {
80     impl->h_array = impl->h_array_borrowed;
81   } else if (impl->h_array_owned) {
82     impl->h_array = impl->h_array_owned;
83   } else {
84     CeedSize length;
85     CeedCallBackend(CeedVectorGetLength(vec, &length));
86     CeedCallBackend(CeedCalloc(length, &impl->h_array_owned));
87     impl->h_array = impl->h_array_owned;
88   }
89 
90   CeedSize length;
91   CeedCallBackend(CeedVectorGetLength(vec, &length));
92   size_t bytes = length * sizeof(CeedScalar);
93   CeedCallHip(ceed, hipMemcpy(impl->h_array, impl->d_array, bytes, hipMemcpyDeviceToHost));
94 
95   return CEED_ERROR_SUCCESS;
96 }
97 
98 //------------------------------------------------------------------------------
99 // Sync arrays
100 //------------------------------------------------------------------------------
101 static int CeedVectorSyncArray_Hip(const CeedVector vec, CeedMemType mem_type) {
102   // Check whether device/host sync is needed
103   bool need_sync = false;
104   CeedCallBackend(CeedVectorNeedSync_Hip(vec, mem_type, &need_sync));
105   if (!need_sync) return CEED_ERROR_SUCCESS;
106 
107   switch (mem_type) {
108     case CEED_MEM_HOST:
109       return CeedVectorSyncD2H_Hip(vec);
110     case CEED_MEM_DEVICE:
111       return CeedVectorSyncH2D_Hip(vec);
112   }
113   return CEED_ERROR_UNSUPPORTED;
114 }
115 
116 //------------------------------------------------------------------------------
117 // Set all pointers as invalid
118 //------------------------------------------------------------------------------
119 static inline int CeedVectorSetAllInvalid_Hip(const CeedVector vec) {
120   CeedVector_Hip *impl;
121   CeedCallBackend(CeedVectorGetData(vec, &impl));
122 
123   impl->h_array = NULL;
124   impl->d_array = NULL;
125 
126   return CEED_ERROR_SUCCESS;
127 }
128 
129 //------------------------------------------------------------------------------
130 // Check if CeedVector has any valid pointers
131 //------------------------------------------------------------------------------
132 static inline int CeedVectorHasValidArray_Hip(const CeedVector vec, bool *has_valid_array) {
133   CeedVector_Hip *impl;
134   CeedCallBackend(CeedVectorGetData(vec, &impl));
135 
136   *has_valid_array = !!impl->h_array || !!impl->d_array;
137 
138   return CEED_ERROR_SUCCESS;
139 }
140 
141 //------------------------------------------------------------------------------
142 // Check if has any array of given type
143 //------------------------------------------------------------------------------
144 static inline int CeedVectorHasArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) {
145   CeedVector_Hip *impl;
146   CeedCallBackend(CeedVectorGetData(vec, &impl));
147 
148   switch (mem_type) {
149     case CEED_MEM_HOST:
150       *has_array_of_type = !!impl->h_array_borrowed || !!impl->h_array_owned;
151       break;
152     case CEED_MEM_DEVICE:
153       *has_array_of_type = !!impl->d_array_borrowed || !!impl->d_array_owned;
154       break;
155   }
156 
157   return CEED_ERROR_SUCCESS;
158 }
159 
160 //------------------------------------------------------------------------------
161 // Check if has borrowed array of given type
162 //------------------------------------------------------------------------------
163 static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) {
164   CeedVector_Hip *impl;
165   CeedCallBackend(CeedVectorGetData(vec, &impl));
166 
167   switch (mem_type) {
168     case CEED_MEM_HOST:
169       *has_borrowed_array_of_type = !!impl->h_array_borrowed;
170       break;
171     case CEED_MEM_DEVICE:
172       *has_borrowed_array_of_type = !!impl->d_array_borrowed;
173       break;
174   }
175 
176   return CEED_ERROR_SUCCESS;
177 }
178 
179 //------------------------------------------------------------------------------
180 // Set array from host
181 //------------------------------------------------------------------------------
182 static int CeedVectorSetArrayHost_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
183   CeedVector_Hip *impl;
184   CeedCallBackend(CeedVectorGetData(vec, &impl));
185 
186   switch (copy_mode) {
187     case CEED_COPY_VALUES: {
188       CeedSize length;
189       if (!impl->h_array_owned) {
190         CeedCallBackend(CeedVectorGetLength(vec, &length));
191         CeedCallBackend(CeedMalloc(length, &impl->h_array_owned));
192       }
193       impl->h_array_borrowed = NULL;
194       impl->h_array          = impl->h_array_owned;
195       if (array) {
196         CeedSize length;
197         CeedCallBackend(CeedVectorGetLength(vec, &length));
198         size_t bytes = length * sizeof(CeedScalar);
199         memcpy(impl->h_array, array, bytes);
200       }
201     } break;
202     case CEED_OWN_POINTER:
203       CeedCallBackend(CeedFree(&impl->h_array_owned));
204       impl->h_array_owned    = array;
205       impl->h_array_borrowed = NULL;
206       impl->h_array          = array;
207       break;
208     case CEED_USE_POINTER:
209       CeedCallBackend(CeedFree(&impl->h_array_owned));
210       impl->h_array_borrowed = array;
211       impl->h_array          = array;
212       break;
213   }
214 
215   return CEED_ERROR_SUCCESS;
216 }
217 
218 //------------------------------------------------------------------------------
219 // Set array from device
220 //------------------------------------------------------------------------------
221 static int CeedVectorSetArrayDevice_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
222   Ceed ceed;
223   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
224   CeedVector_Hip *impl;
225   CeedCallBackend(CeedVectorGetData(vec, &impl));
226 
227   switch (copy_mode) {
228     case CEED_COPY_VALUES: {
229       CeedSize length;
230       CeedCallBackend(CeedVectorGetLength(vec, &length));
231       size_t bytes = length * sizeof(CeedScalar);
232       if (!impl->d_array_owned) {
233         CeedCallHip(ceed, hipMalloc((void **)&impl->d_array_owned, bytes));
234       }
235       impl->d_array_borrowed = NULL;
236       impl->d_array          = impl->d_array_owned;
237       if (array) {
238         CeedCallHip(ceed, hipMemcpy(impl->d_array, array, bytes, hipMemcpyDeviceToDevice));
239       }
240     } break;
241     case CEED_OWN_POINTER:
242       CeedCallHip(ceed, hipFree(impl->d_array_owned));
243       impl->d_array_owned    = array;
244       impl->d_array_borrowed = NULL;
245       impl->d_array          = array;
246       break;
247     case CEED_USE_POINTER:
248       CeedCallHip(ceed, hipFree(impl->d_array_owned));
249       impl->d_array_owned    = NULL;
250       impl->d_array_borrowed = array;
251       impl->d_array          = array;
252       break;
253   }
254 
255   return CEED_ERROR_SUCCESS;
256 }
257 
258 //------------------------------------------------------------------------------
259 // Set the array used by a vector,
260 //   freeing any previously allocated array if applicable
261 //------------------------------------------------------------------------------
262 static int CeedVectorSetArray_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) {
263   Ceed ceed;
264   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
265   CeedVector_Hip *impl;
266   CeedCallBackend(CeedVectorGetData(vec, &impl));
267 
268   CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
269   switch (mem_type) {
270     case CEED_MEM_HOST:
271       return CeedVectorSetArrayHost_Hip(vec, copy_mode, array);
272     case CEED_MEM_DEVICE:
273       return CeedVectorSetArrayDevice_Hip(vec, copy_mode, array);
274   }
275 
276   return CEED_ERROR_UNSUPPORTED;
277 }
278 
279 //------------------------------------------------------------------------------
280 // Set host array to value
281 //------------------------------------------------------------------------------
282 static int CeedHostSetValue_Hip(CeedScalar *h_array, CeedSize length, CeedScalar val) {
283   for (CeedSize i = 0; i < length; i++) h_array[i] = val;
284   return CEED_ERROR_SUCCESS;
285 }
286 
287 //------------------------------------------------------------------------------
288 // Set device array to value (impl in .hip file)
289 //------------------------------------------------------------------------------
290 int CeedDeviceSetValue_Hip(CeedScalar *d_array, CeedSize length, CeedScalar val);
291 
292 //------------------------------------------------------------------------------
293 // Set a vector to a value,
294 //------------------------------------------------------------------------------
295 static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) {
296   Ceed ceed;
297   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
298   CeedVector_Hip *impl;
299   CeedCallBackend(CeedVectorGetData(vec, &impl));
300   CeedSize length;
301   CeedCallBackend(CeedVectorGetLength(vec, &length));
302 
303   // Set value for synced device/host array
304   if (!impl->d_array && !impl->h_array) {
305     if (impl->d_array_borrowed) {
306       impl->d_array = impl->d_array_borrowed;
307     } else if (impl->h_array_borrowed) {
308       impl->h_array = impl->h_array_borrowed;
309     } else if (impl->d_array_owned) {
310       impl->d_array = impl->d_array_owned;
311     } else if (impl->h_array_owned) {
312       impl->h_array = impl->h_array_owned;
313     } else {
314       CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL));
315     }
316   }
317   if (impl->d_array) {
318     CeedCallBackend(CeedDeviceSetValue_Hip(impl->d_array, length, val));
319   }
320   if (impl->h_array) {
321     CeedCallBackend(CeedHostSetValue_Hip(impl->h_array, length, val));
322   }
323 
324   return CEED_ERROR_SUCCESS;
325 }
326 
327 //------------------------------------------------------------------------------
328 // Vector Take Array
329 //------------------------------------------------------------------------------
330 static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
331   Ceed ceed;
332   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
333   CeedVector_Hip *impl;
334   CeedCallBackend(CeedVectorGetData(vec, &impl));
335 
336   // Sync array to requested mem_type
337   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
338 
339   // Update pointer
340   switch (mem_type) {
341     case CEED_MEM_HOST:
342       (*array)               = impl->h_array_borrowed;
343       impl->h_array_borrowed = NULL;
344       impl->h_array          = NULL;
345       break;
346     case CEED_MEM_DEVICE:
347       (*array)               = impl->d_array_borrowed;
348       impl->d_array_borrowed = NULL;
349       impl->d_array          = NULL;
350       break;
351   }
352 
353   return CEED_ERROR_SUCCESS;
354 }
355 
356 //------------------------------------------------------------------------------
// Core logic for array synchronization for GetArray.
358 //   If a different memory type is most up to date, this will perform a copy
359 //------------------------------------------------------------------------------
360 static int CeedVectorGetArrayCore_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
361   Ceed ceed;
362   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
363   CeedVector_Hip *impl;
364   CeedCallBackend(CeedVectorGetData(vec, &impl));
365 
366   // Sync array to requested mem_type
367   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
368 
369   // Update pointer
370   switch (mem_type) {
371     case CEED_MEM_HOST:
372       *array = impl->h_array;
373       break;
374     case CEED_MEM_DEVICE:
375       *array = impl->d_array;
376       break;
377   }
378 
379   return CEED_ERROR_SUCCESS;
380 }
381 
382 //------------------------------------------------------------------------------
383 // Get read-only access to a vector via the specified mem_type
384 //------------------------------------------------------------------------------
385 static int CeedVectorGetArrayRead_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) {
386   return CeedVectorGetArrayCore_Hip(vec, mem_type, (CeedScalar **)array);
387 }
388 
389 //------------------------------------------------------------------------------
390 // Get read/write access to a vector via the specified mem_type
391 //------------------------------------------------------------------------------
392 static int CeedVectorGetArray_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
393   CeedVector_Hip *impl;
394   CeedCallBackend(CeedVectorGetData(vec, &impl));
395 
396   CeedCallBackend(CeedVectorGetArrayCore_Hip(vec, mem_type, array));
397 
398   CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
399   switch (mem_type) {
400     case CEED_MEM_HOST:
401       impl->h_array = *array;
402       break;
403     case CEED_MEM_DEVICE:
404       impl->d_array = *array;
405       break;
406   }
407 
408   return CEED_ERROR_SUCCESS;
409 }
410 
411 //------------------------------------------------------------------------------
412 // Get write access to a vector via the specified mem_type
413 //------------------------------------------------------------------------------
414 static int CeedVectorGetArrayWrite_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
415   CeedVector_Hip *impl;
416   CeedCallBackend(CeedVectorGetData(vec, &impl));
417 
418   bool has_array_of_type = true;
419   CeedCallBackend(CeedVectorHasArrayOfType_Hip(vec, mem_type, &has_array_of_type));
420   if (!has_array_of_type) {
421     // Allocate if array is not yet allocated
422     CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL));
423   } else {
424     // Select dirty array
425     switch (mem_type) {
426       case CEED_MEM_HOST:
427         if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed;
428         else impl->h_array = impl->h_array_owned;
429         break;
430       case CEED_MEM_DEVICE:
431         if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed;
432         else impl->d_array = impl->d_array_owned;
433     }
434   }
435 
436   return CeedVectorGetArray_Hip(vec, mem_type, array);
437 }
438 
439 //------------------------------------------------------------------------------
440 // Get the norm of a CeedVector
441 //------------------------------------------------------------------------------
442 static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *norm) {
443   Ceed ceed;
444   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
445   CeedVector_Hip *impl;
446   CeedCallBackend(CeedVectorGetData(vec, &impl));
447   CeedSize length;
448   CeedCallBackend(CeedVectorGetLength(vec, &length));
449   hipblasHandle_t handle;
450   CeedCallBackend(CeedGetHipblasHandle_Hip(ceed, &handle));
451 
452   // Is the vector too long to handle with int32? If so, we will divide
453   // it up into "int32-sized" subsections and make repeated BLAS calls.
454   CeedSize num_calls = length / INT_MAX;
455   if (length % INT_MAX > 0) num_calls += 1;
456 
457   // Compute norm
458   const CeedScalar *d_array;
459   CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array));
460   switch (type) {
461     case CEED_NORM_1: {
462       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
463         float  sub_norm = 0.0;
464         float *d_array_start;
465         for (CeedInt i = 0; i < num_calls; i++) {
466           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
467           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
468           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
469           CeedCallHipblas(ceed, hipblasSasum(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
470           *norm += sub_norm;
471         }
472       } else {
473         double  sub_norm = 0.0;
474         double *d_array_start;
475         for (CeedInt i = 0; i < num_calls; i++) {
476           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
477           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
478           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
479           CeedCallHipblas(ceed, hipblasDasum(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
480           *norm += sub_norm;
481         }
482       }
483       break;
484     }
485     case CEED_NORM_2: {
486       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
487         float  sub_norm = 0.0, norm_sum = 0.0;
488         float *d_array_start;
489         for (CeedInt i = 0; i < num_calls; i++) {
490           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
491           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
492           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
493           CeedCallHipblas(ceed, hipblasSnrm2(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
494           norm_sum += sub_norm * sub_norm;
495         }
496         *norm = sqrt(norm_sum);
497       } else {
498         double  sub_norm = 0.0, norm_sum = 0.0;
499         double *d_array_start;
500         for (CeedInt i = 0; i < num_calls; i++) {
501           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
502           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
503           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
504           CeedCallHipblas(ceed, hipblasDnrm2(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
505           norm_sum += sub_norm * sub_norm;
506         }
507         *norm = sqrt(norm_sum);
508       }
509       break;
510     }
511     case CEED_NORM_MAX: {
512       CeedInt indx;
513       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
514         float  sub_max = 0.0, current_max = 0.0;
515         float *d_array_start;
516         for (CeedInt i = 0; i < num_calls; i++) {
517           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
518           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
519           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
520           CeedCallHipblas(ceed, hipblasIsamax(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &indx));
521           CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + indx - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
522           if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
523         }
524         *norm = current_max;
525       } else {
526         double  sub_max = 0.0, current_max = 0.0;
527         double *d_array_start;
528         for (CeedInt i = 0; i < num_calls; i++) {
529           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
530           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
531           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
532           CeedCallHipblas(ceed, hipblasIdamax(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &indx));
533           CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + indx - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
534           if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
535         }
536         *norm = current_max;
537       }
538       break;
539     }
540   }
541   CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array));
542 
543   return CEED_ERROR_SUCCESS;
544 }
545 
546 //------------------------------------------------------------------------------
547 // Take reciprocal of a vector on host
548 //------------------------------------------------------------------------------
549 static int CeedHostReciprocal_Hip(CeedScalar *h_array, CeedSize length) {
550   for (CeedSize i = 0; i < length; i++) {
551     if (fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i];
552   }
553   return CEED_ERROR_SUCCESS;
554 }
555 
556 //------------------------------------------------------------------------------
557 // Take reciprocal of a vector on device (impl in .cu file)
558 //------------------------------------------------------------------------------
559 int CeedDeviceReciprocal_Hip(CeedScalar *d_array, CeedSize length);
560 
561 //------------------------------------------------------------------------------
562 // Take reciprocal of a vector
563 //------------------------------------------------------------------------------
564 static int CeedVectorReciprocal_Hip(CeedVector vec) {
565   Ceed ceed;
566   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
567   CeedVector_Hip *impl;
568   CeedCallBackend(CeedVectorGetData(vec, &impl));
569   CeedSize length;
570   CeedCallBackend(CeedVectorGetLength(vec, &length));
571 
572   // Set value for synced device/host array
573   if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Hip(impl->d_array, length));
574   if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Hip(impl->h_array, length));
575 
576   return CEED_ERROR_SUCCESS;
577 }
578 
579 //------------------------------------------------------------------------------
580 // Compute x = alpha x on the host
581 //------------------------------------------------------------------------------
582 static int CeedHostScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length) {
583   for (CeedSize i = 0; i < length; i++) x_array[i] *= alpha;
584   return CEED_ERROR_SUCCESS;
585 }
586 
587 //------------------------------------------------------------------------------
588 // Compute x = alpha x on device (impl in .cu file)
589 //------------------------------------------------------------------------------
590 int CeedDeviceScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length);
591 
592 //------------------------------------------------------------------------------
593 // Compute x = alpha x
594 //------------------------------------------------------------------------------
595 static int CeedVectorScale_Hip(CeedVector x, CeedScalar alpha) {
596   Ceed ceed;
597   CeedCallBackend(CeedVectorGetCeed(x, &ceed));
598   CeedVector_Hip *x_impl;
599   CeedCallBackend(CeedVectorGetData(x, &x_impl));
600   CeedSize length;
601   CeedCallBackend(CeedVectorGetLength(x, &length));
602 
603   // Set value for synced device/host array
604   if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Hip(x_impl->d_array, alpha, length));
605   if (x_impl->h_array) CeedCallBackend(CeedHostScale_Hip(x_impl->h_array, alpha, length));
606 
607   return CEED_ERROR_SUCCESS;
608 }
609 
610 //------------------------------------------------------------------------------
611 // Compute y = alpha x + y on the host
612 //------------------------------------------------------------------------------
613 static int CeedHostAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) {
614   for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i];
615   return CEED_ERROR_SUCCESS;
616 }
617 
618 //------------------------------------------------------------------------------
619 // Compute y = alpha x + y on device (impl in .cu file)
620 //------------------------------------------------------------------------------
621 int CeedDeviceAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length);
622 
623 //------------------------------------------------------------------------------
624 // Compute y = alpha x + y
625 //------------------------------------------------------------------------------
626 static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) {
627   Ceed ceed;
628   CeedCallBackend(CeedVectorGetCeed(y, &ceed));
629   CeedVector_Hip *y_impl, *x_impl;
630   CeedCallBackend(CeedVectorGetData(y, &y_impl));
631   CeedCallBackend(CeedVectorGetData(x, &x_impl));
632   CeedSize length;
633   CeedCallBackend(CeedVectorGetLength(y, &length));
634 
635   // Set value for synced device/host array
636   if (y_impl->d_array) {
637     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
638     CeedCallBackend(CeedDeviceAXPY_Hip(y_impl->d_array, alpha, x_impl->d_array, length));
639   }
640   if (y_impl->h_array) {
641     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
642     CeedCallBackend(CeedHostAXPY_Hip(y_impl->h_array, alpha, x_impl->h_array, length));
643   }
644 
645   return CEED_ERROR_SUCCESS;
646 }
647 
648 //------------------------------------------------------------------------------
649 // Compute y = alpha x + beta y on the host
650 //------------------------------------------------------------------------------
651 static int CeedHostAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length) {
652   for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i] + beta * y_array[i];
653   return CEED_ERROR_SUCCESS;
654 }
655 
656 //------------------------------------------------------------------------------
657 // Compute y = alpha x + beta y on device (impl in .cu file)
658 //------------------------------------------------------------------------------
659 int CeedDeviceAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length);
660 
661 //------------------------------------------------------------------------------
662 // Compute y = alpha x + beta y
663 //------------------------------------------------------------------------------
664 static int CeedVectorAXPBY_Hip(CeedVector y, CeedScalar alpha, CeedScalar beta, CeedVector x) {
665   Ceed ceed;
666   CeedCallBackend(CeedVectorGetCeed(y, &ceed));
667   CeedVector_Hip *y_impl, *x_impl;
668   CeedCallBackend(CeedVectorGetData(y, &y_impl));
669   CeedCallBackend(CeedVectorGetData(x, &x_impl));
670   CeedSize length;
671   CeedCallBackend(CeedVectorGetLength(y, &length));
672 
673   // Set value for synced device/host array
674   if (y_impl->d_array) {
675     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
676     CeedCallBackend(CeedDeviceAXPBY_Hip(y_impl->d_array, alpha, beta, x_impl->d_array, length));
677   }
678   if (y_impl->h_array) {
679     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
680     CeedCallBackend(CeedHostAXPBY_Hip(y_impl->h_array, alpha, beta, x_impl->h_array, length));
681   }
682 
683   return CEED_ERROR_SUCCESS;
684 }
685 
686 //------------------------------------------------------------------------------
687 // Compute the pointwise multiplication w = x .* y on the host
688 //------------------------------------------------------------------------------
689 static int CeedHostPointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) {
690   for (CeedSize i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i];
691   return CEED_ERROR_SUCCESS;
692 }
693 
694 //------------------------------------------------------------------------------
695 // Compute the pointwise multiplication w = x .* y on device (impl in .cu file)
696 //------------------------------------------------------------------------------
697 int CeedDevicePointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length);
698 
699 //------------------------------------------------------------------------------
700 // Compute the pointwise multiplication w = x .* y
701 //------------------------------------------------------------------------------
702 static int CeedVectorPointwiseMult_Hip(CeedVector w, CeedVector x, CeedVector y) {
703   Ceed ceed;
704   CeedCallBackend(CeedVectorGetCeed(w, &ceed));
705   CeedVector_Hip *w_impl, *x_impl, *y_impl;
706   CeedCallBackend(CeedVectorGetData(w, &w_impl));
707   CeedCallBackend(CeedVectorGetData(x, &x_impl));
708   CeedCallBackend(CeedVectorGetData(y, &y_impl));
709   CeedSize length;
710   CeedCallBackend(CeedVectorGetLength(w, &length));
711 
712   // Set value for synced device/host array
713   if (!w_impl->d_array && !w_impl->h_array) {
714     CeedCallBackend(CeedVectorSetValue(w, 0.0));
715   }
716   if (w_impl->d_array) {
717     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
718     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE));
719     CeedCallBackend(CeedDevicePointwiseMult_Hip(w_impl->d_array, x_impl->d_array, y_impl->d_array, length));
720   }
721   if (w_impl->h_array) {
722     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
723     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST));
724     CeedCallBackend(CeedHostPointwiseMult_Hip(w_impl->h_array, x_impl->h_array, y_impl->h_array, length));
725   }
726 
727   return CEED_ERROR_SUCCESS;
728 }
729 
730 //------------------------------------------------------------------------------
731 // Destroy the vector
732 //------------------------------------------------------------------------------
733 static int CeedVectorDestroy_Hip(const CeedVector vec) {
734   Ceed ceed;
735   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
736   CeedVector_Hip *impl;
737   CeedCallBackend(CeedVectorGetData(vec, &impl));
738 
739   CeedCallHip(ceed, hipFree(impl->d_array_owned));
740   CeedCallBackend(CeedFree(&impl->h_array_owned));
741   CeedCallBackend(CeedFree(&impl));
742 
743   return CEED_ERROR_SUCCESS;
744 }
745 
746 //------------------------------------------------------------------------------
747 // Create a vector of the specified length (does not allocate memory)
748 //------------------------------------------------------------------------------
749 int CeedVectorCreate_Hip(CeedSize n, CeedVector vec) {
750   CeedVector_Hip *impl;
751   Ceed            ceed;
752   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
753 
754   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Hip));
755   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Hip));
756   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Hip));
757   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Hip));
758   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", (int (*)())CeedVectorSetValue_Hip));
759   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Hip));
760   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Hip));
761   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Hip));
762   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Hip));
763   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Norm", CeedVectorNorm_Hip));
764   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Hip));
765   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Scale", (int (*)())CeedVectorScale_Hip));
766   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPY", (int (*)())CeedVectorAXPY_Hip));
767   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPBY", (int (*)())CeedVectorAXPBY_Hip));
768   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Hip));
769   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Hip));
770 
771   CeedCallBackend(CeedCalloc(1, &impl));
772   CeedCallBackend(CeedVectorSetData(vec, impl));
773 
774   return CEED_ERROR_SUCCESS;
775 }
776 
777 //------------------------------------------------------------------------------
778