xref: /libCEED/backends/hip-ref/ceed-hip-ref-vector.c (revision baf96a30fc83f1cce83f61e183e51df19381d4f1)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <hip/hip_runtime.h>
11 #include <math.h>
12 #include <stdbool.h>
13 #include <string.h>
14 
15 #include "../hip/ceed-hip-common.h"
16 #include "ceed-hip-ref.h"
17 
18 //------------------------------------------------------------------------------
19 // Check if host/device sync is needed
20 //------------------------------------------------------------------------------
21 static inline int CeedVectorNeedSync_Hip(const CeedVector vec, CeedMemType mem_type, bool *need_sync) {
22   CeedVector_Hip *impl;
23   CeedCallBackend(CeedVectorGetData(vec, &impl));
24 
25   bool has_valid_array = false;
26   CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array));
27   switch (mem_type) {
28     case CEED_MEM_HOST:
29       *need_sync = has_valid_array && !impl->h_array;
30       break;
31     case CEED_MEM_DEVICE:
32       *need_sync = has_valid_array && !impl->d_array;
33       break;
34   }
35 
36   return CEED_ERROR_SUCCESS;
37 }
38 
39 //------------------------------------------------------------------------------
40 // Sync host to device
41 //------------------------------------------------------------------------------
42 static inline int CeedVectorSyncH2D_Hip(const CeedVector vec) {
43   Ceed ceed;
44   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
45   CeedVector_Hip *impl;
46   CeedCallBackend(CeedVectorGetData(vec, &impl));
47 
48   CeedSize length;
49   CeedCallBackend(CeedVectorGetLength(vec, &length));
50   size_t bytes = length * sizeof(CeedScalar);
51 
52   if (!impl->h_array) {
53     // LCOV_EXCL_START
54     return CeedError(ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
55     // LCOV_EXCL_STOP
56   }
57 
58   if (impl->d_array_borrowed) {
59     impl->d_array = impl->d_array_borrowed;
60   } else if (impl->d_array_owned) {
61     impl->d_array = impl->d_array_owned;
62   } else {
63     CeedCallHip(ceed, hipMalloc((void **)&impl->d_array_owned, bytes));
64     impl->d_array = impl->d_array_owned;
65   }
66 
67   CeedCallHip(ceed, hipMemcpy(impl->d_array, impl->h_array, bytes, hipMemcpyHostToDevice));
68 
69   return CEED_ERROR_SUCCESS;
70 }
71 
72 //------------------------------------------------------------------------------
73 // Sync device to host
74 //------------------------------------------------------------------------------
75 static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) {
76   Ceed ceed;
77   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
78   CeedVector_Hip *impl;
79   CeedCallBackend(CeedVectorGetData(vec, &impl));
80 
81   if (!impl->d_array) {
82     // LCOV_EXCL_START
83     return CeedError(ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
84     // LCOV_EXCL_STOP
85   }
86 
87   if (impl->h_array_borrowed) {
88     impl->h_array = impl->h_array_borrowed;
89   } else if (impl->h_array_owned) {
90     impl->h_array = impl->h_array_owned;
91   } else {
92     CeedSize length;
93     CeedCallBackend(CeedVectorGetLength(vec, &length));
94     CeedCallBackend(CeedCalloc(length, &impl->h_array_owned));
95     impl->h_array = impl->h_array_owned;
96   }
97 
98   CeedSize length;
99   CeedCallBackend(CeedVectorGetLength(vec, &length));
100   size_t bytes = length * sizeof(CeedScalar);
101   CeedCallHip(ceed, hipMemcpy(impl->h_array, impl->d_array, bytes, hipMemcpyDeviceToHost));
102 
103   return CEED_ERROR_SUCCESS;
104 }
105 
106 //------------------------------------------------------------------------------
107 // Sync arrays
108 //------------------------------------------------------------------------------
109 static int CeedVectorSyncArray_Hip(const CeedVector vec, CeedMemType mem_type) {
110   // Check whether device/host sync is needed
111   bool need_sync = false;
112   CeedCallBackend(CeedVectorNeedSync_Hip(vec, mem_type, &need_sync));
113   if (!need_sync) return CEED_ERROR_SUCCESS;
114 
115   switch (mem_type) {
116     case CEED_MEM_HOST:
117       return CeedVectorSyncD2H_Hip(vec);
118     case CEED_MEM_DEVICE:
119       return CeedVectorSyncH2D_Hip(vec);
120   }
121   return CEED_ERROR_UNSUPPORTED;
122 }
123 
124 //------------------------------------------------------------------------------
125 // Set all pointers as invalid
126 //------------------------------------------------------------------------------
127 static inline int CeedVectorSetAllInvalid_Hip(const CeedVector vec) {
128   CeedVector_Hip *impl;
129   CeedCallBackend(CeedVectorGetData(vec, &impl));
130 
131   impl->h_array = NULL;
132   impl->d_array = NULL;
133 
134   return CEED_ERROR_SUCCESS;
135 }
136 
137 //------------------------------------------------------------------------------
138 // Check if CeedVector has any valid pointers
139 //------------------------------------------------------------------------------
140 static inline int CeedVectorHasValidArray_Hip(const CeedVector vec, bool *has_valid_array) {
141   CeedVector_Hip *impl;
142   CeedCallBackend(CeedVectorGetData(vec, &impl));
143 
144   *has_valid_array = !!impl->h_array || !!impl->d_array;
145 
146   return CEED_ERROR_SUCCESS;
147 }
148 
149 //------------------------------------------------------------------------------
150 // Check if has any array of given type
151 //------------------------------------------------------------------------------
152 static inline int CeedVectorHasArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) {
153   CeedVector_Hip *impl;
154   CeedCallBackend(CeedVectorGetData(vec, &impl));
155 
156   switch (mem_type) {
157     case CEED_MEM_HOST:
158       *has_array_of_type = !!impl->h_array_borrowed || !!impl->h_array_owned;
159       break;
160     case CEED_MEM_DEVICE:
161       *has_array_of_type = !!impl->d_array_borrowed || !!impl->d_array_owned;
162       break;
163   }
164 
165   return CEED_ERROR_SUCCESS;
166 }
167 
168 //------------------------------------------------------------------------------
169 // Check if has borrowed array of given type
170 //------------------------------------------------------------------------------
171 static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) {
172   CeedVector_Hip *impl;
173   CeedCallBackend(CeedVectorGetData(vec, &impl));
174 
175   switch (mem_type) {
176     case CEED_MEM_HOST:
177       *has_borrowed_array_of_type = !!impl->h_array_borrowed;
178       break;
179     case CEED_MEM_DEVICE:
180       *has_borrowed_array_of_type = !!impl->d_array_borrowed;
181       break;
182   }
183 
184   return CEED_ERROR_SUCCESS;
185 }
186 
187 //------------------------------------------------------------------------------
188 // Set array from host
189 //------------------------------------------------------------------------------
190 static int CeedVectorSetArrayHost_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
191   CeedVector_Hip *impl;
192   CeedCallBackend(CeedVectorGetData(vec, &impl));
193 
194   switch (copy_mode) {
195     case CEED_COPY_VALUES: {
196       CeedSize length;
197       if (!impl->h_array_owned) {
198         CeedCallBackend(CeedVectorGetLength(vec, &length));
199         CeedCallBackend(CeedMalloc(length, &impl->h_array_owned));
200       }
201       impl->h_array_borrowed = NULL;
202       impl->h_array          = impl->h_array_owned;
203       if (array) {
204         CeedSize length;
205         CeedCallBackend(CeedVectorGetLength(vec, &length));
206         size_t bytes = length * sizeof(CeedScalar);
207         memcpy(impl->h_array, array, bytes);
208       }
209     } break;
210     case CEED_OWN_POINTER:
211       CeedCallBackend(CeedFree(&impl->h_array_owned));
212       impl->h_array_owned    = array;
213       impl->h_array_borrowed = NULL;
214       impl->h_array          = array;
215       break;
216     case CEED_USE_POINTER:
217       CeedCallBackend(CeedFree(&impl->h_array_owned));
218       impl->h_array_borrowed = array;
219       impl->h_array          = array;
220       break;
221   }
222 
223   return CEED_ERROR_SUCCESS;
224 }
225 
226 //------------------------------------------------------------------------------
227 // Set array from device
228 //------------------------------------------------------------------------------
229 static int CeedVectorSetArrayDevice_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
230   Ceed ceed;
231   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
232   CeedVector_Hip *impl;
233   CeedCallBackend(CeedVectorGetData(vec, &impl));
234 
235   switch (copy_mode) {
236     case CEED_COPY_VALUES: {
237       CeedSize length;
238       CeedCallBackend(CeedVectorGetLength(vec, &length));
239       size_t bytes = length * sizeof(CeedScalar);
240       if (!impl->d_array_owned) {
241         CeedCallHip(ceed, hipMalloc((void **)&impl->d_array_owned, bytes));
242       }
243       impl->d_array_borrowed = NULL;
244       impl->d_array          = impl->d_array_owned;
245       if (array) {
246         CeedCallHip(ceed, hipMemcpy(impl->d_array, array, bytes, hipMemcpyDeviceToDevice));
247       }
248     } break;
249     case CEED_OWN_POINTER:
250       CeedCallHip(ceed, hipFree(impl->d_array_owned));
251       impl->d_array_owned    = array;
252       impl->d_array_borrowed = NULL;
253       impl->d_array          = array;
254       break;
255     case CEED_USE_POINTER:
256       CeedCallHip(ceed, hipFree(impl->d_array_owned));
257       impl->d_array_owned    = NULL;
258       impl->d_array_borrowed = array;
259       impl->d_array          = array;
260       break;
261   }
262 
263   return CEED_ERROR_SUCCESS;
264 }
265 
266 //------------------------------------------------------------------------------
267 // Set the array used by a vector,
268 //   freeing any previously allocated array if applicable
269 //------------------------------------------------------------------------------
270 static int CeedVectorSetArray_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) {
271   Ceed ceed;
272   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
273   CeedVector_Hip *impl;
274   CeedCallBackend(CeedVectorGetData(vec, &impl));
275 
276   CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
277   switch (mem_type) {
278     case CEED_MEM_HOST:
279       return CeedVectorSetArrayHost_Hip(vec, copy_mode, array);
280     case CEED_MEM_DEVICE:
281       return CeedVectorSetArrayDevice_Hip(vec, copy_mode, array);
282   }
283 
284   return CEED_ERROR_UNSUPPORTED;
285 }
286 
287 //------------------------------------------------------------------------------
288 // Set host array to value
289 //------------------------------------------------------------------------------
290 static int CeedHostSetValue_Hip(CeedScalar *h_array, CeedInt length, CeedScalar val) {
291   for (int i = 0; i < length; i++) h_array[i] = val;
292   return CEED_ERROR_SUCCESS;
293 }
294 
295 //------------------------------------------------------------------------------
296 // Set device array to value (impl in .hip file)
297 //------------------------------------------------------------------------------
298 int CeedDeviceSetValue_Hip(CeedScalar *d_array, CeedInt length, CeedScalar val);
299 
300 //------------------------------------------------------------------------------
301 // Set a vector to a value,
302 //------------------------------------------------------------------------------
303 static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) {
304   Ceed ceed;
305   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
306   CeedVector_Hip *impl;
307   CeedCallBackend(CeedVectorGetData(vec, &impl));
308   CeedSize length;
309   CeedCallBackend(CeedVectorGetLength(vec, &length));
310 
311   // Set value for synced device/host array
312   if (!impl->d_array && !impl->h_array) {
313     if (impl->d_array_borrowed) {
314       impl->d_array = impl->d_array_borrowed;
315     } else if (impl->h_array_borrowed) {
316       impl->h_array = impl->h_array_borrowed;
317     } else if (impl->d_array_owned) {
318       impl->d_array = impl->d_array_owned;
319     } else if (impl->h_array_owned) {
320       impl->h_array = impl->h_array_owned;
321     } else {
322       CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL));
323     }
324   }
325   if (impl->d_array) {
326     CeedCallBackend(CeedDeviceSetValue_Hip(impl->d_array, length, val));
327   }
328   if (impl->h_array) {
329     CeedCallBackend(CeedHostSetValue_Hip(impl->h_array, length, val));
330   }
331 
332   return CEED_ERROR_SUCCESS;
333 }
334 
335 //------------------------------------------------------------------------------
336 // Vector Take Array
337 //------------------------------------------------------------------------------
338 static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
339   Ceed ceed;
340   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
341   CeedVector_Hip *impl;
342   CeedCallBackend(CeedVectorGetData(vec, &impl));
343 
344   // Sync array to requested mem_type
345   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
346 
347   // Update pointer
348   switch (mem_type) {
349     case CEED_MEM_HOST:
350       (*array)               = impl->h_array_borrowed;
351       impl->h_array_borrowed = NULL;
352       impl->h_array          = NULL;
353       break;
354     case CEED_MEM_DEVICE:
355       (*array)               = impl->d_array_borrowed;
356       impl->d_array_borrowed = NULL;
357       impl->d_array          = NULL;
358       break;
359   }
360 
361   return CEED_ERROR_SUCCESS;
362 }
363 
364 //------------------------------------------------------------------------------
365 // Core logic for array syncronization for GetArray.
366 //   If a different memory type is most up to date, this will perform a copy
367 //------------------------------------------------------------------------------
368 static int CeedVectorGetArrayCore_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
369   Ceed ceed;
370   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
371   CeedVector_Hip *impl;
372   CeedCallBackend(CeedVectorGetData(vec, &impl));
373 
374   // Sync array to requested mem_type
375   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
376 
377   // Update pointer
378   switch (mem_type) {
379     case CEED_MEM_HOST:
380       *array = impl->h_array;
381       break;
382     case CEED_MEM_DEVICE:
383       *array = impl->d_array;
384       break;
385   }
386 
387   return CEED_ERROR_SUCCESS;
388 }
389 
390 //------------------------------------------------------------------------------
391 // Get read-only access to a vector via the specified mem_type
392 //------------------------------------------------------------------------------
393 static int CeedVectorGetArrayRead_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) {
394   return CeedVectorGetArrayCore_Hip(vec, mem_type, (CeedScalar **)array);
395 }
396 
397 //------------------------------------------------------------------------------
398 // Get read/write access to a vector via the specified mem_type
399 //------------------------------------------------------------------------------
400 static int CeedVectorGetArray_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
401   CeedVector_Hip *impl;
402   CeedCallBackend(CeedVectorGetData(vec, &impl));
403 
404   CeedCallBackend(CeedVectorGetArrayCore_Hip(vec, mem_type, array));
405 
406   CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
407   switch (mem_type) {
408     case CEED_MEM_HOST:
409       impl->h_array = *array;
410       break;
411     case CEED_MEM_DEVICE:
412       impl->d_array = *array;
413       break;
414   }
415 
416   return CEED_ERROR_SUCCESS;
417 }
418 
419 //------------------------------------------------------------------------------
420 // Get write access to a vector via the specified mem_type
421 //------------------------------------------------------------------------------
422 static int CeedVectorGetArrayWrite_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
423   CeedVector_Hip *impl;
424   CeedCallBackend(CeedVectorGetData(vec, &impl));
425 
426   bool has_array_of_type = true;
427   CeedCallBackend(CeedVectorHasArrayOfType_Hip(vec, mem_type, &has_array_of_type));
428   if (!has_array_of_type) {
429     // Allocate if array is not yet allocated
430     CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL));
431   } else {
432     // Select dirty array
433     switch (mem_type) {
434       case CEED_MEM_HOST:
435         if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed;
436         else impl->h_array = impl->h_array_owned;
437         break;
438       case CEED_MEM_DEVICE:
439         if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed;
440         else impl->d_array = impl->d_array_owned;
441     }
442   }
443 
444   return CeedVectorGetArray_Hip(vec, mem_type, array);
445 }
446 
447 //------------------------------------------------------------------------------
448 // Get the norm of a CeedVector
449 //------------------------------------------------------------------------------
450 static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *norm) {
451   Ceed ceed;
452   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
453   CeedVector_Hip *impl;
454   CeedCallBackend(CeedVectorGetData(vec, &impl));
455   CeedSize length;
456   CeedCallBackend(CeedVectorGetLength(vec, &length));
457   hipblasHandle_t handle;
458   CeedCallBackend(CeedHipGetHipblasHandle(ceed, &handle));
459 
460   // Compute norm
461   const CeedScalar *d_array;
462   CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array));
463   switch (type) {
464     case CEED_NORM_1: {
465       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
466         CeedCallHipblas(ceed, hipblasSasum(handle, length, (float *)d_array, 1, (float *)norm));
467       } else {
468         CeedCallHipblas(ceed, hipblasDasum(handle, length, (double *)d_array, 1, (double *)norm));
469       }
470       break;
471     }
472     case CEED_NORM_2: {
473       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
474         CeedCallHipblas(ceed, hipblasSnrm2(handle, length, (float *)d_array, 1, (float *)norm));
475       } else {
476         CeedCallHipblas(ceed, hipblasDnrm2(handle, length, (double *)d_array, 1, (double *)norm));
477       }
478       break;
479     }
480     case CEED_NORM_MAX: {
481       CeedInt indx;
482       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
483         CeedCallHipblas(ceed, hipblasIsamax(handle, length, (float *)d_array, 1, &indx));
484       } else {
485         CeedCallHipblas(ceed, hipblasIdamax(handle, length, (double *)d_array, 1, &indx));
486       }
487       CeedScalar normNoAbs;
488       CeedCallHip(ceed, hipMemcpy(&normNoAbs, impl->d_array + indx - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
489       *norm = fabs(normNoAbs);
490       break;
491     }
492   }
493   CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array));
494 
495   return CEED_ERROR_SUCCESS;
496 }
497 
498 //------------------------------------------------------------------------------
499 // Take reciprocal of a vector on host
500 //------------------------------------------------------------------------------
501 static int CeedHostReciprocal_Hip(CeedScalar *h_array, CeedInt length) {
502   for (int i = 0; i < length; i++) {
503     if (fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i];
504   }
505   return CEED_ERROR_SUCCESS;
506 }
507 
508 //------------------------------------------------------------------------------
509 // Take reciprocal of a vector on device (impl in .cu file)
510 //------------------------------------------------------------------------------
511 int CeedDeviceReciprocal_Hip(CeedScalar *d_array, CeedInt length);
512 
513 //------------------------------------------------------------------------------
514 // Take reciprocal of a vector
515 //------------------------------------------------------------------------------
516 static int CeedVectorReciprocal_Hip(CeedVector vec) {
517   Ceed ceed;
518   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
519   CeedVector_Hip *impl;
520   CeedCallBackend(CeedVectorGetData(vec, &impl));
521   CeedSize length;
522   CeedCallBackend(CeedVectorGetLength(vec, &length));
523 
524   // Set value for synced device/host array
525   if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Hip(impl->d_array, length));
526   if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Hip(impl->h_array, length));
527 
528   return CEED_ERROR_SUCCESS;
529 }
530 
531 //------------------------------------------------------------------------------
532 // Compute x = alpha x on the host
533 //------------------------------------------------------------------------------
534 static int CeedHostScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedInt length) {
535   for (int i = 0; i < length; i++) x_array[i] *= alpha;
536   return CEED_ERROR_SUCCESS;
537 }
538 
539 //------------------------------------------------------------------------------
540 // Compute x = alpha x on device (impl in .cu file)
541 //------------------------------------------------------------------------------
542 int CeedDeviceScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedInt length);
543 
544 //------------------------------------------------------------------------------
545 // Compute x = alpha x
546 //------------------------------------------------------------------------------
547 static int CeedVectorScale_Hip(CeedVector x, CeedScalar alpha) {
548   Ceed ceed;
549   CeedCallBackend(CeedVectorGetCeed(x, &ceed));
550   CeedVector_Hip *x_impl;
551   CeedCallBackend(CeedVectorGetData(x, &x_impl));
552   CeedSize length;
553   CeedCallBackend(CeedVectorGetLength(x, &length));
554 
555   // Set value for synced device/host array
556   if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Hip(x_impl->d_array, alpha, length));
557   if (x_impl->h_array) CeedCallBackend(CeedHostScale_Hip(x_impl->h_array, alpha, length));
558 
559   return CEED_ERROR_SUCCESS;
560 }
561 
562 //------------------------------------------------------------------------------
563 // Compute y = alpha x + y on the host
564 //------------------------------------------------------------------------------
565 static int CeedHostAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedInt length) {
566   for (int i = 0; i < length; i++) y_array[i] += alpha * x_array[i];
567   return CEED_ERROR_SUCCESS;
568 }
569 
570 //------------------------------------------------------------------------------
571 // Compute y = alpha x + y on device (impl in .cu file)
572 //------------------------------------------------------------------------------
573 int CeedDeviceAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedInt length);
574 
575 //------------------------------------------------------------------------------
576 // Compute y = alpha x + y
577 //------------------------------------------------------------------------------
578 static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) {
579   Ceed ceed;
580   CeedCallBackend(CeedVectorGetCeed(y, &ceed));
581   CeedVector_Hip *y_impl, *x_impl;
582   CeedCallBackend(CeedVectorGetData(y, &y_impl));
583   CeedCallBackend(CeedVectorGetData(x, &x_impl));
584   CeedSize length;
585   CeedCallBackend(CeedVectorGetLength(y, &length));
586 
587   // Set value for synced device/host array
588   if (y_impl->d_array) {
589     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
590     CeedCallBackend(CeedDeviceAXPY_Hip(y_impl->d_array, alpha, x_impl->d_array, length));
591   }
592   if (y_impl->h_array) {
593     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
594     CeedCallBackend(CeedHostAXPY_Hip(y_impl->h_array, alpha, x_impl->h_array, length));
595   }
596 
597   return CEED_ERROR_SUCCESS;
598 }
599 //------------------------------------------------------------------------------
600 // Compute y = alpha x + beta y on the host
601 //------------------------------------------------------------------------------
602 static int CeedHostAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedInt length) {
603   for (int i = 0; i < length; i++) y_array[i] += alpha * x_array[i] + beta * y_array[i];
604   return CEED_ERROR_SUCCESS;
605 }
606 
607 //------------------------------------------------------------------------------
608 // Compute y = alpha x + beta y on device (impl in .cu file)
609 //------------------------------------------------------------------------------
610 int CeedDeviceAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedInt length);
611 
612 //------------------------------------------------------------------------------
613 // Compute y = alpha x + beta y
614 //------------------------------------------------------------------------------
615 static int CeedVectorAXPBY_Hip(CeedVector y, CeedScalar alpha, CeedScalar beta, CeedVector x) {
616   Ceed ceed;
617   CeedCallBackend(CeedVectorGetCeed(y, &ceed));
618   CeedVector_Hip *y_impl, *x_impl;
619   CeedCallBackend(CeedVectorGetData(y, &y_impl));
620   CeedCallBackend(CeedVectorGetData(x, &x_impl));
621   CeedSize length;
622   CeedCallBackend(CeedVectorGetLength(y, &length));
623 
624   // Set value for synced device/host array
625   if (y_impl->d_array) {
626     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
627     CeedCallBackend(CeedDeviceAXPBY_Hip(y_impl->d_array, alpha, beta, x_impl->d_array, length));
628   }
629   if (y_impl->h_array) {
630     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
631     CeedCallBackend(CeedHostAXPBY_Hip(y_impl->h_array, alpha, beta, x_impl->h_array, length));
632   }
633 
634   return CEED_ERROR_SUCCESS;
635 }
636 
637 //------------------------------------------------------------------------------
638 // Compute the pointwise multiplication w = x .* y on the host
639 //------------------------------------------------------------------------------
640 static int CeedHostPointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedInt length) {
641   for (int i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i];
642   return CEED_ERROR_SUCCESS;
643 }
644 
645 //------------------------------------------------------------------------------
646 // Compute the pointwise multiplication w = x .* y on device (impl in .cu file)
647 //------------------------------------------------------------------------------
648 int CeedDevicePointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedInt length);
649 
650 //------------------------------------------------------------------------------
651 // Compute the pointwise multiplication w = x .* y
652 //------------------------------------------------------------------------------
653 static int CeedVectorPointwiseMult_Hip(CeedVector w, CeedVector x, CeedVector y) {
654   Ceed ceed;
655   CeedCallBackend(CeedVectorGetCeed(w, &ceed));
656   CeedVector_Hip *w_impl, *x_impl, *y_impl;
657   CeedCallBackend(CeedVectorGetData(w, &w_impl));
658   CeedCallBackend(CeedVectorGetData(x, &x_impl));
659   CeedCallBackend(CeedVectorGetData(y, &y_impl));
660   CeedSize length;
661   CeedCallBackend(CeedVectorGetLength(w, &length));
662 
663   // Set value for synced device/host array
664   if (!w_impl->d_array && !w_impl->h_array) {
665     CeedCallBackend(CeedVectorSetValue(w, 0.0));
666   }
667   if (w_impl->d_array) {
668     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
669     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE));
670     CeedCallBackend(CeedDevicePointwiseMult_Hip(w_impl->d_array, x_impl->d_array, y_impl->d_array, length));
671   }
672   if (w_impl->h_array) {
673     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
674     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST));
675     CeedCallBackend(CeedHostPointwiseMult_Hip(w_impl->h_array, x_impl->h_array, y_impl->h_array, length));
676   }
677 
678   return CEED_ERROR_SUCCESS;
679 }
680 
681 //------------------------------------------------------------------------------
682 // Destroy the vector
683 //------------------------------------------------------------------------------
684 static int CeedVectorDestroy_Hip(const CeedVector vec) {
685   Ceed ceed;
686   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
687   CeedVector_Hip *impl;
688   CeedCallBackend(CeedVectorGetData(vec, &impl));
689 
690   CeedCallHip(ceed, hipFree(impl->d_array_owned));
691   CeedCallBackend(CeedFree(&impl->h_array_owned));
692   CeedCallBackend(CeedFree(&impl));
693 
694   return CEED_ERROR_SUCCESS;
695 }
696 
697 //------------------------------------------------------------------------------
698 // Create a vector of the specified length (does not allocate memory)
699 //------------------------------------------------------------------------------
700 int CeedVectorCreate_Hip(CeedSize n, CeedVector vec) {
701   CeedVector_Hip *impl;
702   Ceed            ceed;
703   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
704 
705   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Hip));
706   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Hip));
707   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Hip));
708   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Hip));
709   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", CeedVectorSetValue_Hip));
710   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Hip));
711   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Hip));
712   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Hip));
713   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Hip));
714   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Norm", CeedVectorNorm_Hip));
715   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Hip));
716   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Scale", CeedVectorScale_Hip));
717   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPY", CeedVectorAXPY_Hip));
718   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPBY", CeedVectorAXPBY_Hip));
719   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Hip));
720   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Hip));
721 
722   CeedCallBackend(CeedCalloc(1, &impl));
723   CeedCallBackend(CeedVectorSetData(vec, impl));
724 
725   return CEED_ERROR_SUCCESS;
726 }
727