xref: /libCEED/backends/hip-ref/ceed-hip-ref-vector.c (revision f59ebe5e436dfb5c2d53e11c74c720481a8777c4)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <math.h>
11 #include <stdbool.h>
12 #include <string.h>
13 #include <hip/hip_runtime.h>
14 
15 #include "../hip/ceed-hip-common.h"
16 #include "ceed-hip-ref.h"
17 
18 //------------------------------------------------------------------------------
19 // Check if host/device sync is needed
20 //------------------------------------------------------------------------------
21 static inline int CeedVectorNeedSync_Hip(const CeedVector vec, CeedMemType mem_type, bool *need_sync) {
22   CeedVector_Hip *impl;
23   bool            has_valid_array = false;
24 
25   CeedCallBackend(CeedVectorGetData(vec, &impl));
26   CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array));
27   switch (mem_type) {
28     case CEED_MEM_HOST:
29       *need_sync = has_valid_array && !impl->h_array;
30       break;
31     case CEED_MEM_DEVICE:
32       *need_sync = has_valid_array && !impl->d_array;
33       break;
34   }
35   return CEED_ERROR_SUCCESS;
36 }
37 
38 //------------------------------------------------------------------------------
39 // Sync host to device
40 //------------------------------------------------------------------------------
41 static inline int CeedVectorSyncH2D_Hip(const CeedVector vec) {
42   Ceed            ceed;
43   CeedSize        length;
44   size_t          bytes;
45   CeedVector_Hip *impl;
46 
47   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
48   CeedCallBackend(CeedVectorGetData(vec, &impl));
49 
50   CeedCheck(impl->h_array, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
51 
52   CeedCallBackend(CeedVectorGetLength(vec, &length));
53   bytes = length * sizeof(CeedScalar);
54   if (impl->d_array_borrowed) {
55     impl->d_array = impl->d_array_borrowed;
56   } else if (impl->d_array_owned) {
57     impl->d_array = impl->d_array_owned;
58   } else {
59     CeedCallHip(ceed, hipMalloc((void **)&impl->d_array_owned, bytes));
60     impl->d_array = impl->d_array_owned;
61   }
62   CeedCallHip(ceed, hipMemcpy(impl->d_array, impl->h_array, bytes, hipMemcpyHostToDevice));
63   return CEED_ERROR_SUCCESS;
64 }
65 
66 //------------------------------------------------------------------------------
67 // Sync device to host
68 //------------------------------------------------------------------------------
69 static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) {
70   Ceed            ceed;
71   CeedSize        length;
72   size_t          bytes;
73   CeedVector_Hip *impl;
74 
75   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
76   CeedCallBackend(CeedVectorGetData(vec, &impl));
77 
78   CeedCheck(impl->d_array, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
79 
80   if (impl->h_array_borrowed) {
81     impl->h_array = impl->h_array_borrowed;
82   } else if (impl->h_array_owned) {
83     impl->h_array = impl->h_array_owned;
84   } else {
85     CeedSize length;
86 
87     CeedCallBackend(CeedVectorGetLength(vec, &length));
88     CeedCallBackend(CeedCalloc(length, &impl->h_array_owned));
89     impl->h_array = impl->h_array_owned;
90   }
91 
92   CeedCallBackend(CeedVectorGetLength(vec, &length));
93   bytes = length * sizeof(CeedScalar);
94   CeedCallHip(ceed, hipMemcpy(impl->h_array, impl->d_array, bytes, hipMemcpyDeviceToHost));
95   return CEED_ERROR_SUCCESS;
96 }
97 
98 //------------------------------------------------------------------------------
99 // Sync arrays
100 //------------------------------------------------------------------------------
101 static int CeedVectorSyncArray_Hip(const CeedVector vec, CeedMemType mem_type) {
102   bool need_sync = false;
103 
104   // Check whether device/host sync is needed
105   CeedCallBackend(CeedVectorNeedSync_Hip(vec, mem_type, &need_sync));
106   if (!need_sync) return CEED_ERROR_SUCCESS;
107 
108   switch (mem_type) {
109     case CEED_MEM_HOST:
110       return CeedVectorSyncD2H_Hip(vec);
111     case CEED_MEM_DEVICE:
112       return CeedVectorSyncH2D_Hip(vec);
113   }
114   return CEED_ERROR_UNSUPPORTED;
115 }
116 
117 //------------------------------------------------------------------------------
118 // Set all pointers as invalid
119 //------------------------------------------------------------------------------
120 static inline int CeedVectorSetAllInvalid_Hip(const CeedVector vec) {
121   CeedVector_Hip *impl;
122 
123   CeedCallBackend(CeedVectorGetData(vec, &impl));
124   impl->h_array = NULL;
125   impl->d_array = NULL;
126   return CEED_ERROR_SUCCESS;
127 }
128 
129 //------------------------------------------------------------------------------
130 // Check if CeedVector has any valid pointer
131 //------------------------------------------------------------------------------
132 static inline int CeedVectorHasValidArray_Hip(const CeedVector vec, bool *has_valid_array) {
133   CeedVector_Hip *impl;
134 
135   CeedCallBackend(CeedVectorGetData(vec, &impl));
136   *has_valid_array = impl->h_array || impl->d_array;
137   return CEED_ERROR_SUCCESS;
138 }
139 
140 //------------------------------------------------------------------------------
141 // Check if has array of given type
142 //------------------------------------------------------------------------------
143 static inline int CeedVectorHasArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) {
144   CeedVector_Hip *impl;
145 
146   CeedCallBackend(CeedVectorGetData(vec, &impl));
147   switch (mem_type) {
148     case CEED_MEM_HOST:
149       *has_array_of_type = impl->h_array_borrowed || impl->h_array_owned;
150       break;
151     case CEED_MEM_DEVICE:
152       *has_array_of_type = impl->d_array_borrowed || impl->d_array_owned;
153       break;
154   }
155   return CEED_ERROR_SUCCESS;
156 }
157 
158 //------------------------------------------------------------------------------
159 // Check if has borrowed array of given type
160 //------------------------------------------------------------------------------
161 static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) {
162   CeedVector_Hip *impl;
163 
164   CeedCallBackend(CeedVectorGetData(vec, &impl));
165   switch (mem_type) {
166     case CEED_MEM_HOST:
167       *has_borrowed_array_of_type = impl->h_array_borrowed;
168       break;
169     case CEED_MEM_DEVICE:
170       *has_borrowed_array_of_type = impl->d_array_borrowed;
171       break;
172   }
173   return CEED_ERROR_SUCCESS;
174 }
175 
176 //------------------------------------------------------------------------------
177 // Set array from host
178 //------------------------------------------------------------------------------
179 static int CeedVectorSetArrayHost_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
180   CeedSize        length;
181   CeedVector_Hip *impl;
182 
183   CeedCallBackend(CeedVectorGetData(vec, &impl));
184   CeedCallBackend(CeedVectorGetLength(vec, &length));
185 
186   switch (copy_mode) {
187     case CEED_COPY_VALUES: {
188       if (!impl->h_array_owned) CeedCallBackend(CeedMalloc(length, &impl->h_array_owned));
189       if (array) memcpy(impl->h_array_owned, array, length * sizeof(array[0]));
190       impl->h_array_borrowed = NULL;
191       impl->h_array          = impl->h_array_owned;
192     } break;
193     case CEED_OWN_POINTER:
194       CeedCallBackend(CeedFree(&impl->h_array_owned));
195       impl->h_array_owned    = array;
196       impl->h_array_borrowed = NULL;
197       impl->h_array          = impl->h_array_owned;
198       break;
199     case CEED_USE_POINTER:
200       CeedCallBackend(CeedFree(&impl->h_array_owned));
201       impl->h_array_borrowed = array;
202       impl->h_array          = impl->h_array_borrowed;
203       break;
204   }
205   return CEED_ERROR_SUCCESS;
206 }
207 
208 //------------------------------------------------------------------------------
209 // Set array from device
210 //------------------------------------------------------------------------------
211 static int CeedVectorSetArrayDevice_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
212   CeedSize        length;
213   Ceed            ceed;
214   CeedVector_Hip *impl;
215 
216   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
217   CeedCallBackend(CeedVectorGetData(vec, &impl));
218   CeedCallBackend(CeedVectorGetLength(vec, &length));
219   switch (copy_mode) {
220     case CEED_COPY_VALUES: {
221       if (!impl->d_array_owned) CeedCallHip(ceed, hipMalloc((void **)&impl->d_array_owned, length * sizeof(array[0])));
222       if (array) CeedCallHip(ceed, hipMemcpy(impl->d_array_owned, array, length * sizeof(array[0]), hipMemcpyDeviceToDevice));
223       impl->d_array_borrowed = NULL;
224       impl->d_array          = impl->d_array_owned;
225     } break;
226     case CEED_OWN_POINTER:
227       CeedCallHip(ceed, hipFree(impl->d_array_owned));
228       impl->d_array_owned    = array;
229       impl->d_array_borrowed = NULL;
230       impl->d_array          = impl->d_array_owned;
231       break;
232     case CEED_USE_POINTER:
233       CeedCallHip(ceed, hipFree(impl->d_array_owned));
234       impl->d_array_owned    = NULL;
235       impl->d_array_borrowed = array;
236       impl->d_array          = impl->d_array_borrowed;
237       break;
238   }
239   return CEED_ERROR_SUCCESS;
240 }
241 
242 //------------------------------------------------------------------------------
243 // Set the array used by a vector,
244 //   freeing any previously allocated array if applicable
245 //------------------------------------------------------------------------------
246 static int CeedVectorSetArray_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) {
247   CeedVector_Hip *impl;
248 
249   CeedCallBackend(CeedVectorGetData(vec, &impl));
250   CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
251   switch (mem_type) {
252     case CEED_MEM_HOST:
253       return CeedVectorSetArrayHost_Hip(vec, copy_mode, array);
254     case CEED_MEM_DEVICE:
255       return CeedVectorSetArrayDevice_Hip(vec, copy_mode, array);
256   }
257   return CEED_ERROR_UNSUPPORTED;
258 }
259 
260 //------------------------------------------------------------------------------
261 // Set host array to value
262 //------------------------------------------------------------------------------
263 static int CeedHostSetValue_Hip(CeedScalar *h_array, CeedSize length, CeedScalar val) {
264   for (CeedSize i = 0; i < length; i++) h_array[i] = val;
265   return CEED_ERROR_SUCCESS;
266 }
267 
//------------------------------------------------------------------------------
// Set device array to value (impl in .hip file)
//   Only declared here. Presumably the device counterpart of
//   CeedHostSetValue_Hip, writing val to length entries of d_array - confirm
//   against the implementation.
//------------------------------------------------------------------------------
int CeedDeviceSetValue_Hip(CeedScalar *d_array, CeedSize length, CeedScalar val);
272 
273 //------------------------------------------------------------------------------
274 // Set a vector to a value
275 //------------------------------------------------------------------------------
276 static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) {
277   CeedSize        length;
278   CeedVector_Hip *impl;
279 
280   CeedCallBackend(CeedVectorGetData(vec, &impl));
281   CeedCallBackend(CeedVectorGetLength(vec, &length));
282   // Set value for synced device/host array
283   if (!impl->d_array && !impl->h_array) {
284     if (impl->d_array_borrowed) {
285       impl->d_array = impl->d_array_borrowed;
286     } else if (impl->h_array_borrowed) {
287       impl->h_array = impl->h_array_borrowed;
288     } else if (impl->d_array_owned) {
289       impl->d_array = impl->d_array_owned;
290     } else if (impl->h_array_owned) {
291       impl->h_array = impl->h_array_owned;
292     } else {
293       CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL));
294     }
295   }
296   if (impl->d_array) {
297     CeedCallBackend(CeedDeviceSetValue_Hip(impl->d_array, length, val));
298     impl->h_array = NULL;
299   }
300   if (impl->h_array) {
301     CeedCallBackend(CeedHostSetValue_Hip(impl->h_array, length, val));
302     impl->d_array = NULL;
303   }
304   return CEED_ERROR_SUCCESS;
305 }
306 
307 //------------------------------------------------------------------------------
308 // Vector Take Array
309 //------------------------------------------------------------------------------
310 static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
311   CeedVector_Hip *impl;
312 
313   CeedCallBackend(CeedVectorGetData(vec, &impl));
314 
315   // Sync array to requested mem_type
316   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
317 
318   // Update pointer
319   switch (mem_type) {
320     case CEED_MEM_HOST:
321       (*array)               = impl->h_array_borrowed;
322       impl->h_array_borrowed = NULL;
323       impl->h_array          = NULL;
324       break;
325     case CEED_MEM_DEVICE:
326       (*array)               = impl->d_array_borrowed;
327       impl->d_array_borrowed = NULL;
328       impl->d_array          = NULL;
329       break;
330   }
331   return CEED_ERROR_SUCCESS;
332 }
333 
334 //------------------------------------------------------------------------------
335 // Core logic for array syncronization for GetArray.
336 //   If a different memory type is most up to date, this will perform a copy
337 //------------------------------------------------------------------------------
338 static int CeedVectorGetArrayCore_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
339   CeedVector_Hip *impl;
340 
341   CeedCallBackend(CeedVectorGetData(vec, &impl));
342 
343   // Sync array to requested mem_type
344   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
345 
346   // Update pointer
347   switch (mem_type) {
348     case CEED_MEM_HOST:
349       *array = impl->h_array;
350       break;
351     case CEED_MEM_DEVICE:
352       *array = impl->d_array;
353       break;
354   }
355   return CEED_ERROR_SUCCESS;
356 }
357 
358 //------------------------------------------------------------------------------
359 // Get read-only access to a vector via the specified mem_type
360 //------------------------------------------------------------------------------
361 static int CeedVectorGetArrayRead_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) {
362   return CeedVectorGetArrayCore_Hip(vec, mem_type, (CeedScalar **)array);
363 }
364 
365 //------------------------------------------------------------------------------
366 // Get read/write access to a vector via the specified mem_type
367 //------------------------------------------------------------------------------
368 static int CeedVectorGetArray_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
369   CeedVector_Hip *impl;
370 
371   CeedCallBackend(CeedVectorGetData(vec, &impl));
372   CeedCallBackend(CeedVectorGetArrayCore_Hip(vec, mem_type, array));
373   CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
374   switch (mem_type) {
375     case CEED_MEM_HOST:
376       impl->h_array = *array;
377       break;
378     case CEED_MEM_DEVICE:
379       impl->d_array = *array;
380       break;
381   }
382   return CEED_ERROR_SUCCESS;
383 }
384 
385 //------------------------------------------------------------------------------
386 // Get write access to a vector via the specified mem_type
387 //------------------------------------------------------------------------------
388 static int CeedVectorGetArrayWrite_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
389   bool            has_array_of_type = true;
390   CeedVector_Hip *impl;
391 
392   CeedCallBackend(CeedVectorGetData(vec, &impl));
393   CeedCallBackend(CeedVectorHasArrayOfType_Hip(vec, mem_type, &has_array_of_type));
394   if (!has_array_of_type) {
395     // Allocate if array is not yet allocated
396     CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL));
397   } else {
398     // Select dirty array
399     switch (mem_type) {
400       case CEED_MEM_HOST:
401         if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed;
402         else impl->h_array = impl->h_array_owned;
403         break;
404       case CEED_MEM_DEVICE:
405         if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed;
406         else impl->d_array = impl->d_array_owned;
407     }
408   }
409   return CeedVectorGetArray_Hip(vec, mem_type, array);
410 }
411 
//------------------------------------------------------------------------------
// Get the norm of a CeedVector
//   The norm is computed on the device with hipBLAS. The vector is processed
//   in chunks of at most INT_MAX entries, and the per-chunk results are
//   combined on the host:
//     CEED_NORM_1:   sum of the chunk results
//     CEED_NORM_2:   square root of the sum of squared chunk results
//     CEED_NORM_MAX: largest absolute value found over the chunks
//   NOTE(review): INT_MAX needs <limits.h>, which this file does not include
//   directly - presumably it arrives through another header; confirm.
//   NOTE(review): if CeedCallHipblas/CeedCallHip return early on failure, the
//   read access taken below is not restored on that path - confirm.
//------------------------------------------------------------------------------
static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *norm) {
  Ceed              ceed;
  CeedSize          length, num_calls;
  const CeedScalar *d_array;
  CeedVector_Hip   *impl;
  hipblasHandle_t   handle;

  CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
  CeedCallBackend(CeedVectorGetData(vec, &impl));
  CeedCallBackend(CeedVectorGetLength(vec, &length));
  CeedCallBackend(CeedGetHipblasHandle_Hip(ceed, &handle));

  // Is the vector too long to handle with int32? If so, we will divide
  // it up into "int32-sized" subsections and make repeated BLAS calls.
  // num_calls is length / INT_MAX rounded up
  num_calls = length / INT_MAX;
  if (length % INT_MAX > 0) num_calls += 1;

  // Compute norm
  CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array));
  switch (type) {
    case CEED_NORM_1: {
      *norm = 0.0;
      // The precision branch selects the matching hipBLAS routine (S* or D*)
      if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
        float  sub_norm = 0.0;
        float *d_array_start;

        for (CeedInt i = 0; i < num_calls; i++) {
          // Chunk i starts i * INT_MAX entries in; only the last chunk may be shorter
          d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
          CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
          CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;

          CeedCallHipblas(ceed, hipblasSasum(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
          *norm += sub_norm;
        }
      } else {
        double  sub_norm = 0.0;
        double *d_array_start;

        for (CeedInt i = 0; i < num_calls; i++) {
          d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
          CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
          CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;

          CeedCallHipblas(ceed, hipblasDasum(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
          *norm += sub_norm;
        }
      }
      break;
    }
    case CEED_NORM_2: {
      if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
        float  sub_norm = 0.0, norm_sum = 0.0;
        float *d_array_start;

        for (CeedInt i = 0; i < num_calls; i++) {
          d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
          CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
          CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;

          CeedCallHipblas(ceed, hipblasSnrm2(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
          // Accumulate squared chunk norms; the square root is taken once at the end
          norm_sum += sub_norm * sub_norm;
        }
        *norm = sqrt(norm_sum);
      } else {
        double  sub_norm = 0.0, norm_sum = 0.0;
        double *d_array_start;

        for (CeedInt i = 0; i < num_calls; i++) {
          d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
          CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
          CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;

          CeedCallHipblas(ceed, hipblasDnrm2(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
          norm_sum += sub_norm * sub_norm;
        }
        *norm = sqrt(norm_sum);
      }
      break;
    }
    case CEED_NORM_MAX: {
      // Position reported by hipblasI?amax; used below as a 1-based offset (index - 1)
      CeedInt index;

      if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
        float  sub_max = 0.0, current_max = 0.0;
        float *d_array_start;
        for (CeedInt i = 0; i < num_calls; i++) {
          d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
          CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
          CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;

          CeedCallHipblas(ceed, hipblasIsamax(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &index));
          // Fetch the single entry at the reported position back to the host
          CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
          if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
        }
        *norm = current_max;
      } else {
        double  sub_max = 0.0, current_max = 0.0;
        double *d_array_start;

        for (CeedInt i = 0; i < num_calls; i++) {
          d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
          CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
          CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;

          CeedCallHipblas(ceed, hipblasIdamax(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &index));
          CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
          if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
        }
        *norm = current_max;
      }
      break;
    }
  }
  CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array));
  return CEED_ERROR_SUCCESS;
}
531 
532 //------------------------------------------------------------------------------
533 // Take reciprocal of a vector on host
534 //------------------------------------------------------------------------------
535 static int CeedHostReciprocal_Hip(CeedScalar *h_array, CeedSize length) {
536   for (CeedSize i = 0; i < length; i++) {
537     if (fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i];
538   }
539   return CEED_ERROR_SUCCESS;
540 }
541 
//------------------------------------------------------------------------------
// Take reciprocal of a vector on device
//   Only declared here; the implementation is not in this file. The original
//   note said "impl in .cu file", likely meaning the .hip source - TODO confirm.
//   Presumably the device counterpart of CeedHostReciprocal_Hip - confirm.
//------------------------------------------------------------------------------
int CeedDeviceReciprocal_Hip(CeedScalar *d_array, CeedSize length);
546 
547 //------------------------------------------------------------------------------
548 // Take reciprocal of a vector
549 //------------------------------------------------------------------------------
550 static int CeedVectorReciprocal_Hip(CeedVector vec) {
551   CeedSize        length;
552   CeedVector_Hip *impl;
553 
554   CeedCallBackend(CeedVectorGetData(vec, &impl));
555   CeedCallBackend(CeedVectorGetLength(vec, &length));
556   // Set value for synced device/host array
557   if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Hip(impl->d_array, length));
558   if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Hip(impl->h_array, length));
559   return CEED_ERROR_SUCCESS;
560 }
561 
562 //------------------------------------------------------------------------------
563 // Compute x = alpha x on the host
564 //------------------------------------------------------------------------------
565 static int CeedHostScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length) {
566   for (CeedSize i = 0; i < length; i++) x_array[i] *= alpha;
567   return CEED_ERROR_SUCCESS;
568 }
569 
//------------------------------------------------------------------------------
// Compute x = alpha x on device
//   Only declared here; the implementation is not in this file. The original
//   note said "impl in .cu file", likely meaning the .hip source - TODO confirm.
//------------------------------------------------------------------------------
int CeedDeviceScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length);
574 
575 //------------------------------------------------------------------------------
576 // Compute x = alpha x
577 //------------------------------------------------------------------------------
578 static int CeedVectorScale_Hip(CeedVector x, CeedScalar alpha) {
579   CeedSize        length;
580   CeedVector_Hip *x_impl;
581 
582   CeedCallBackend(CeedVectorGetData(x, &x_impl));
583   CeedCallBackend(CeedVectorGetLength(x, &length));
584   // Set value for synced device/host array
585   if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Hip(x_impl->d_array, alpha, length));
586   if (x_impl->h_array) CeedCallBackend(CeedHostScale_Hip(x_impl->h_array, alpha, length));
587   return CEED_ERROR_SUCCESS;
588 }
589 
590 //------------------------------------------------------------------------------
591 // Compute y = alpha x + y on the host
592 //------------------------------------------------------------------------------
593 static int CeedHostAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) {
594   for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i];
595   return CEED_ERROR_SUCCESS;
596 }
597 
//------------------------------------------------------------------------------
// Compute y = alpha x + y on device
//   Only declared here; the implementation is not in this file. The original
//   note said "impl in .cu file", likely meaning the .hip source - TODO confirm.
//------------------------------------------------------------------------------
int CeedDeviceAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length);
602 
603 //------------------------------------------------------------------------------
604 // Compute y = alpha x + y
605 //------------------------------------------------------------------------------
606 static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) {
607   CeedSize        length;
608   CeedVector_Hip *y_impl, *x_impl;
609 
610   CeedCallBackend(CeedVectorGetData(y, &y_impl));
611   CeedCallBackend(CeedVectorGetData(x, &x_impl));
612   CeedCallBackend(CeedVectorGetLength(y, &length));
613   // Set value for synced device/host array
614   if (y_impl->d_array) {
615     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
616     CeedCallBackend(CeedDeviceAXPY_Hip(y_impl->d_array, alpha, x_impl->d_array, length));
617   }
618   if (y_impl->h_array) {
619     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
620     CeedCallBackend(CeedHostAXPY_Hip(y_impl->h_array, alpha, x_impl->h_array, length));
621   }
622   return CEED_ERROR_SUCCESS;
623 }
624 
625 //------------------------------------------------------------------------------
626 // Compute y = alpha x + beta y on the host
627 //------------------------------------------------------------------------------
628 static int CeedHostAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length) {
629   for (CeedSize i = 0; i < length; i++) y_array[i] = alpha * x_array[i] + beta * y_array[i];
630   return CEED_ERROR_SUCCESS;
631 }
632 
//------------------------------------------------------------------------------
// Compute y = alpha x + beta y on device
//   Only declared here; the implementation is not in this file. The original
//   note said "impl in .cu file", likely meaning the .hip source - TODO confirm.
//------------------------------------------------------------------------------
int CeedDeviceAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length);
637 
638 //------------------------------------------------------------------------------
639 // Compute y = alpha x + beta y
640 //------------------------------------------------------------------------------
641 static int CeedVectorAXPBY_Hip(CeedVector y, CeedScalar alpha, CeedScalar beta, CeedVector x) {
642   CeedSize        length;
643   CeedVector_Hip *y_impl, *x_impl;
644 
645   CeedCallBackend(CeedVectorGetData(y, &y_impl));
646   CeedCallBackend(CeedVectorGetData(x, &x_impl));
647   CeedCallBackend(CeedVectorGetLength(y, &length));
648   // Set value for synced device/host array
649   if (y_impl->d_array) {
650     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
651     CeedCallBackend(CeedDeviceAXPBY_Hip(y_impl->d_array, alpha, beta, x_impl->d_array, length));
652   }
653   if (y_impl->h_array) {
654     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
655     CeedCallBackend(CeedHostAXPBY_Hip(y_impl->h_array, alpha, beta, x_impl->h_array, length));
656   }
657   return CEED_ERROR_SUCCESS;
658 }
659 
660 //------------------------------------------------------------------------------
661 // Compute the pointwise multiplication w = x .* y on the host
662 //------------------------------------------------------------------------------
663 static int CeedHostPointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) {
664   for (CeedSize i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i];
665   return CEED_ERROR_SUCCESS;
666 }
667 
//------------------------------------------------------------------------------
// Compute the pointwise multiplication w = x .* y on device
//   Only declared here; the implementation is not in this file. The original
//   note said "impl in .cu file", likely meaning the .hip source - TODO confirm.
//------------------------------------------------------------------------------
int CeedDevicePointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length);
672 
673 //------------------------------------------------------------------------------
674 // Compute the pointwise multiplication w = x .* y
675 //------------------------------------------------------------------------------
676 static int CeedVectorPointwiseMult_Hip(CeedVector w, CeedVector x, CeedVector y) {
677   CeedSize        length;
678   CeedVector_Hip *w_impl, *x_impl, *y_impl;
679 
680   CeedCallBackend(CeedVectorGetData(w, &w_impl));
681   CeedCallBackend(CeedVectorGetData(x, &x_impl));
682   CeedCallBackend(CeedVectorGetData(y, &y_impl));
683   CeedCallBackend(CeedVectorGetLength(w, &length));
684 
685   // Set value for synced device/host array
686   if (!w_impl->d_array && !w_impl->h_array) {
687     CeedCallBackend(CeedVectorSetValue(w, 0.0));
688   }
689   if (w_impl->d_array) {
690     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
691     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE));
692     CeedCallBackend(CeedDevicePointwiseMult_Hip(w_impl->d_array, x_impl->d_array, y_impl->d_array, length));
693   }
694   if (w_impl->h_array) {
695     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
696     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST));
697     CeedCallBackend(CeedHostPointwiseMult_Hip(w_impl->h_array, x_impl->h_array, y_impl->h_array, length));
698   }
699   return CEED_ERROR_SUCCESS;
700 }
701 
702 //------------------------------------------------------------------------------
703 // Destroy the vector
704 //------------------------------------------------------------------------------
705 static int CeedVectorDestroy_Hip(const CeedVector vec) {
706   CeedVector_Hip *impl;
707 
708   CeedCallBackend(CeedVectorGetData(vec, &impl));
709   CeedCallHip(CeedVectorReturnCeed(vec), hipFree(impl->d_array_owned));
710   CeedCallBackend(CeedFree(&impl->h_array_owned));
711   CeedCallBackend(CeedFree(&impl));
712   return CEED_ERROR_SUCCESS;
713 }
714 
715 //------------------------------------------------------------------------------
716 // Create a vector of the specified length (does not allocate memory)
717 //------------------------------------------------------------------------------
718 int CeedVectorCreate_Hip(CeedSize n, CeedVector vec) {
719   CeedVector_Hip *impl;
720   Ceed            ceed;
721 
722   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
723   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Hip));
724   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Hip));
725   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Hip));
726   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Hip));
727   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", (int (*)())CeedVectorSetValue_Hip));
728   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Hip));
729   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Hip));
730   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Hip));
731   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Hip));
732   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Norm", CeedVectorNorm_Hip));
733   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Hip));
734   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Scale", (int (*)())CeedVectorScale_Hip));
735   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPY", (int (*)())CeedVectorAXPY_Hip));
736   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPBY", (int (*)())CeedVectorAXPBY_Hip));
737   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Hip));
738   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Hip));
739   CeedCallBackend(CeedCalloc(1, &impl));
740   CeedCallBackend(CeedVectorSetData(vec, impl));
741   return CEED_ERROR_SUCCESS;
742 }
743 
744 //------------------------------------------------------------------------------
745