xref: /libCEED/backends/hip-ref/ceed-hip-ref-vector.c (revision 404fb0e2c1afa90f8620af98e57297f952bca7c5)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/backend.h>
9 #include <ceed/ceed.h>
10 #include <hip/hip_runtime.h>
11 #include <math.h>
12 #include <string.h>
13 
14 #include "ceed-hip-ref.h"
15 
16 //------------------------------------------------------------------------------
17 // Check if host/device sync is needed
18 //------------------------------------------------------------------------------
19 static inline int CeedVectorNeedSync_Hip(const CeedVector vec, CeedMemType mem_type, bool *need_sync) {
20   CeedVector_Hip *impl;
21   CeedCallBackend(CeedVectorGetData(vec, &impl));
22 
23   bool has_valid_array = false;
24   CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array));
25   switch (mem_type) {
26     case CEED_MEM_HOST:
27       *need_sync = has_valid_array && !impl->h_array;
28       break;
29     case CEED_MEM_DEVICE:
30       *need_sync = has_valid_array && !impl->d_array;
31       break;
32   }
33 
34   return CEED_ERROR_SUCCESS;
35 }
36 
37 //------------------------------------------------------------------------------
38 // Sync host to device
39 //------------------------------------------------------------------------------
40 static inline int CeedVectorSyncH2D_Hip(const CeedVector vec) {
41   Ceed ceed;
42   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
43   CeedVector_Hip *impl;
44   CeedCallBackend(CeedVectorGetData(vec, &impl));
45 
46   CeedSize length;
47   CeedCallBackend(CeedVectorGetLength(vec, &length));
48   size_t bytes = length * sizeof(CeedScalar);
49 
50   if (!impl->h_array) {
51     // LCOV_EXCL_START
52     return CeedError(ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
53     // LCOV_EXCL_STOP
54   }
55 
56   if (impl->d_array_borrowed) {
57     impl->d_array = impl->d_array_borrowed;
58   } else if (impl->d_array_owned) {
59     impl->d_array = impl->d_array_owned;
60   } else {
61     CeedCallHip(ceed, hipMalloc((void **)&impl->d_array_owned, bytes));
62     impl->d_array = impl->d_array_owned;
63   }
64 
65   CeedCallHip(ceed, hipMemcpy(impl->d_array, impl->h_array, bytes, hipMemcpyHostToDevice));
66 
67   return CEED_ERROR_SUCCESS;
68 }
69 
70 //------------------------------------------------------------------------------
71 // Sync device to host
72 //------------------------------------------------------------------------------
73 static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) {
74   Ceed ceed;
75   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
76   CeedVector_Hip *impl;
77   CeedCallBackend(CeedVectorGetData(vec, &impl));
78 
79   if (!impl->d_array) {
80     // LCOV_EXCL_START
81     return CeedError(ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
82     // LCOV_EXCL_STOP
83   }
84 
85   if (impl->h_array_borrowed) {
86     impl->h_array = impl->h_array_borrowed;
87   } else if (impl->h_array_owned) {
88     impl->h_array = impl->h_array_owned;
89   } else {
90     CeedSize length;
91     CeedCallBackend(CeedVectorGetLength(vec, &length));
92     CeedCallBackend(CeedCalloc(length, &impl->h_array_owned));
93     impl->h_array = impl->h_array_owned;
94   }
95 
96   CeedSize length;
97   CeedCallBackend(CeedVectorGetLength(vec, &length));
98   size_t bytes = length * sizeof(CeedScalar);
99   CeedCallHip(ceed, hipMemcpy(impl->h_array, impl->d_array, bytes, hipMemcpyDeviceToHost));
100 
101   return CEED_ERROR_SUCCESS;
102 }
103 
104 //------------------------------------------------------------------------------
105 // Sync arrays
106 //------------------------------------------------------------------------------
107 static int CeedVectorSyncArray_Hip(const CeedVector vec, CeedMemType mem_type) {
108   // Check whether device/host sync is needed
109   bool need_sync = false;
110   CeedCallBackend(CeedVectorNeedSync_Hip(vec, mem_type, &need_sync));
111   if (!need_sync) return CEED_ERROR_SUCCESS;
112 
113   switch (mem_type) {
114     case CEED_MEM_HOST:
115       return CeedVectorSyncD2H_Hip(vec);
116     case CEED_MEM_DEVICE:
117       return CeedVectorSyncH2D_Hip(vec);
118   }
119   return CEED_ERROR_UNSUPPORTED;
120 }
121 
122 //------------------------------------------------------------------------------
123 // Set all pointers as invalid
124 //------------------------------------------------------------------------------
125 static inline int CeedVectorSetAllInvalid_Hip(const CeedVector vec) {
126   CeedVector_Hip *impl;
127   CeedCallBackend(CeedVectorGetData(vec, &impl));
128 
129   impl->h_array = NULL;
130   impl->d_array = NULL;
131 
132   return CEED_ERROR_SUCCESS;
133 }
134 
135 //------------------------------------------------------------------------------
136 // Check if CeedVector has any valid pointers
137 //------------------------------------------------------------------------------
138 static inline int CeedVectorHasValidArray_Hip(const CeedVector vec, bool *has_valid_array) {
139   CeedVector_Hip *impl;
140   CeedCallBackend(CeedVectorGetData(vec, &impl));
141 
142   *has_valid_array = !!impl->h_array || !!impl->d_array;
143 
144   return CEED_ERROR_SUCCESS;
145 }
146 
147 //------------------------------------------------------------------------------
148 // Check if has any array of given type
149 //------------------------------------------------------------------------------
150 static inline int CeedVectorHasArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) {
151   CeedVector_Hip *impl;
152   CeedCallBackend(CeedVectorGetData(vec, &impl));
153 
154   switch (mem_type) {
155     case CEED_MEM_HOST:
156       *has_array_of_type = !!impl->h_array_borrowed || !!impl->h_array_owned;
157       break;
158     case CEED_MEM_DEVICE:
159       *has_array_of_type = !!impl->d_array_borrowed || !!impl->d_array_owned;
160       break;
161   }
162 
163   return CEED_ERROR_SUCCESS;
164 }
165 
166 //------------------------------------------------------------------------------
167 // Check if has borrowed array of given type
168 //------------------------------------------------------------------------------
169 static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) {
170   CeedVector_Hip *impl;
171   CeedCallBackend(CeedVectorGetData(vec, &impl));
172 
173   switch (mem_type) {
174     case CEED_MEM_HOST:
175       *has_borrowed_array_of_type = !!impl->h_array_borrowed;
176       break;
177     case CEED_MEM_DEVICE:
178       *has_borrowed_array_of_type = !!impl->d_array_borrowed;
179       break;
180   }
181 
182   return CEED_ERROR_SUCCESS;
183 }
184 
185 //------------------------------------------------------------------------------
186 // Set array from host
187 //------------------------------------------------------------------------------
188 static int CeedVectorSetArrayHost_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
189   CeedVector_Hip *impl;
190   CeedCallBackend(CeedVectorGetData(vec, &impl));
191 
192   switch (copy_mode) {
193     case CEED_COPY_VALUES: {
194       CeedSize length;
195       if (!impl->h_array_owned) {
196         CeedCallBackend(CeedVectorGetLength(vec, &length));
197         CeedCallBackend(CeedMalloc(length, &impl->h_array_owned));
198       }
199       impl->h_array_borrowed = NULL;
200       impl->h_array          = impl->h_array_owned;
201       if (array) {
202         CeedSize length;
203         CeedCallBackend(CeedVectorGetLength(vec, &length));
204         size_t bytes = length * sizeof(CeedScalar);
205         memcpy(impl->h_array, array, bytes);
206       }
207     } break;
208     case CEED_OWN_POINTER:
209       CeedCallBackend(CeedFree(&impl->h_array_owned));
210       impl->h_array_owned    = array;
211       impl->h_array_borrowed = NULL;
212       impl->h_array          = array;
213       break;
214     case CEED_USE_POINTER:
215       CeedCallBackend(CeedFree(&impl->h_array_owned));
216       impl->h_array_borrowed = array;
217       impl->h_array          = array;
218       break;
219   }
220 
221   return CEED_ERROR_SUCCESS;
222 }
223 
224 //------------------------------------------------------------------------------
225 // Set array from device
226 //------------------------------------------------------------------------------
227 static int CeedVectorSetArrayDevice_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
228   Ceed ceed;
229   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
230   CeedVector_Hip *impl;
231   CeedCallBackend(CeedVectorGetData(vec, &impl));
232 
233   switch (copy_mode) {
234     case CEED_COPY_VALUES: {
235       CeedSize length;
236       CeedCallBackend(CeedVectorGetLength(vec, &length));
237       size_t bytes = length * sizeof(CeedScalar);
238       if (!impl->d_array_owned) {
239         CeedCallHip(ceed, hipMalloc((void **)&impl->d_array_owned, bytes));
240       }
241       impl->d_array_borrowed = NULL;
242       impl->d_array          = impl->d_array_owned;
243       if (array) {
244         CeedCallHip(ceed, hipMemcpy(impl->d_array, array, bytes, hipMemcpyDeviceToDevice));
245       }
246     } break;
247     case CEED_OWN_POINTER:
248       CeedCallHip(ceed, hipFree(impl->d_array_owned));
249       impl->d_array_owned    = array;
250       impl->d_array_borrowed = NULL;
251       impl->d_array          = array;
252       break;
253     case CEED_USE_POINTER:
254       CeedCallHip(ceed, hipFree(impl->d_array_owned));
255       impl->d_array_owned    = NULL;
256       impl->d_array_borrowed = array;
257       impl->d_array          = array;
258       break;
259   }
260 
261   return CEED_ERROR_SUCCESS;
262 }
263 
264 //------------------------------------------------------------------------------
265 // Set the array used by a vector,
266 //   freeing any previously allocated array if applicable
267 //------------------------------------------------------------------------------
268 static int CeedVectorSetArray_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) {
269   Ceed ceed;
270   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
271   CeedVector_Hip *impl;
272   CeedCallBackend(CeedVectorGetData(vec, &impl));
273 
274   CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
275   switch (mem_type) {
276     case CEED_MEM_HOST:
277       return CeedVectorSetArrayHost_Hip(vec, copy_mode, array);
278     case CEED_MEM_DEVICE:
279       return CeedVectorSetArrayDevice_Hip(vec, copy_mode, array);
280   }
281 
282   return CEED_ERROR_UNSUPPORTED;
283 }
284 
285 //------------------------------------------------------------------------------
286 // Set host array to value
287 //------------------------------------------------------------------------------
288 static int CeedHostSetValue_Hip(CeedScalar *h_array, CeedInt length, CeedScalar val) {
289   for (int i = 0; i < length; i++) h_array[i] = val;
290   return CEED_ERROR_SUCCESS;
291 }
292 
293 //------------------------------------------------------------------------------
294 // Set device array to value (impl in .hip file)
295 //------------------------------------------------------------------------------
296 int CeedDeviceSetValue_Hip(CeedScalar *d_array, CeedInt length, CeedScalar val);
297 
298 //------------------------------------------------------------------------------
299 // Set a vector to a value,
300 //------------------------------------------------------------------------------
301 static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) {
302   Ceed ceed;
303   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
304   CeedVector_Hip *impl;
305   CeedCallBackend(CeedVectorGetData(vec, &impl));
306   CeedSize length;
307   CeedCallBackend(CeedVectorGetLength(vec, &length));
308 
309   // Set value for synced device/host array
310   if (!impl->d_array && !impl->h_array) {
311     if (impl->d_array_borrowed) {
312       impl->d_array = impl->d_array_borrowed;
313     } else if (impl->h_array_borrowed) {
314       impl->h_array = impl->h_array_borrowed;
315     } else if (impl->d_array_owned) {
316       impl->d_array = impl->d_array_owned;
317     } else if (impl->h_array_owned) {
318       impl->h_array = impl->h_array_owned;
319     } else {
320       CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL));
321     }
322   }
323   if (impl->d_array) {
324     CeedCallBackend(CeedDeviceSetValue_Hip(impl->d_array, length, val));
325   }
326   if (impl->h_array) {
327     CeedCallBackend(CeedHostSetValue_Hip(impl->h_array, length, val));
328   }
329 
330   return CEED_ERROR_SUCCESS;
331 }
332 
333 //------------------------------------------------------------------------------
334 // Vector Take Array
335 //------------------------------------------------------------------------------
336 static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
337   Ceed ceed;
338   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
339   CeedVector_Hip *impl;
340   CeedCallBackend(CeedVectorGetData(vec, &impl));
341 
342   // Sync array to requested mem_type
343   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
344 
345   // Update pointer
346   switch (mem_type) {
347     case CEED_MEM_HOST:
348       (*array)               = impl->h_array_borrowed;
349       impl->h_array_borrowed = NULL;
350       impl->h_array          = NULL;
351       break;
352     case CEED_MEM_DEVICE:
353       (*array)               = impl->d_array_borrowed;
354       impl->d_array_borrowed = NULL;
355       impl->d_array          = NULL;
356       break;
357   }
358 
359   return CEED_ERROR_SUCCESS;
360 }
361 
362 //------------------------------------------------------------------------------
363 // Core logic for array syncronization for GetArray.
364 //   If a different memory type is most up to date, this will perform a copy
365 //------------------------------------------------------------------------------
366 static int CeedVectorGetArrayCore_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
367   Ceed ceed;
368   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
369   CeedVector_Hip *impl;
370   CeedCallBackend(CeedVectorGetData(vec, &impl));
371 
372   // Sync array to requested mem_type
373   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
374 
375   // Update pointer
376   switch (mem_type) {
377     case CEED_MEM_HOST:
378       *array = impl->h_array;
379       break;
380     case CEED_MEM_DEVICE:
381       *array = impl->d_array;
382       break;
383   }
384 
385   return CEED_ERROR_SUCCESS;
386 }
387 
388 //------------------------------------------------------------------------------
389 // Get read-only access to a vector via the specified mem_type
390 //------------------------------------------------------------------------------
391 static int CeedVectorGetArrayRead_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) {
392   return CeedVectorGetArrayCore_Hip(vec, mem_type, (CeedScalar **)array);
393 }
394 
395 //------------------------------------------------------------------------------
396 // Get read/write access to a vector via the specified mem_type
397 //------------------------------------------------------------------------------
398 static int CeedVectorGetArray_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
399   CeedVector_Hip *impl;
400   CeedCallBackend(CeedVectorGetData(vec, &impl));
401 
402   CeedCallBackend(CeedVectorGetArrayCore_Hip(vec, mem_type, array));
403 
404   CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
405   switch (mem_type) {
406     case CEED_MEM_HOST:
407       impl->h_array = *array;
408       break;
409     case CEED_MEM_DEVICE:
410       impl->d_array = *array;
411       break;
412   }
413 
414   return CEED_ERROR_SUCCESS;
415 }
416 
417 //------------------------------------------------------------------------------
418 // Get write access to a vector via the specified mem_type
419 //------------------------------------------------------------------------------
420 static int CeedVectorGetArrayWrite_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
421   CeedVector_Hip *impl;
422   CeedCallBackend(CeedVectorGetData(vec, &impl));
423 
424   bool has_array_of_type = true;
425   CeedCallBackend(CeedVectorHasArrayOfType_Hip(vec, mem_type, &has_array_of_type));
426   if (!has_array_of_type) {
427     // Allocate if array is not yet allocated
428     CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL));
429   } else {
430     // Select dirty array
431     switch (mem_type) {
432       case CEED_MEM_HOST:
433         if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed;
434         else impl->h_array = impl->h_array_owned;
435         break;
436       case CEED_MEM_DEVICE:
437         if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed;
438         else impl->d_array = impl->d_array_owned;
439     }
440   }
441 
442   return CeedVectorGetArray_Hip(vec, mem_type, array);
443 }
444 
445 //------------------------------------------------------------------------------
446 // Get the norm of a CeedVector
447 //------------------------------------------------------------------------------
448 static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *norm) {
449   Ceed ceed;
450   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
451   CeedVector_Hip *impl;
452   CeedCallBackend(CeedVectorGetData(vec, &impl));
453   CeedSize length;
454   CeedCallBackend(CeedVectorGetLength(vec, &length));
455   hipblasHandle_t handle;
456   CeedCallBackend(CeedHipGetHipblasHandle(ceed, &handle));
457 
458   // Compute norm
459   const CeedScalar *d_array;
460   CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array));
461   switch (type) {
462     case CEED_NORM_1: {
463       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
464         CeedCallHipblas(ceed, hipblasSasum(handle, length, (float *)d_array, 1, (float *)norm));
465       } else {
466         CeedCallHipblas(ceed, hipblasDasum(handle, length, (double *)d_array, 1, (double *)norm));
467       }
468       break;
469     }
470     case CEED_NORM_2: {
471       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
472         CeedCallHipblas(ceed, hipblasSnrm2(handle, length, (float *)d_array, 1, (float *)norm));
473       } else {
474         CeedCallHipblas(ceed, hipblasDnrm2(handle, length, (double *)d_array, 1, (double *)norm));
475       }
476       break;
477     }
478     case CEED_NORM_MAX: {
479       CeedInt indx;
480       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
481         CeedCallHipblas(ceed, hipblasIsamax(handle, length, (float *)d_array, 1, &indx));
482       } else {
483         CeedCallHipblas(ceed, hipblasIdamax(handle, length, (double *)d_array, 1, &indx));
484       }
485       CeedScalar normNoAbs;
486       CeedCallHip(ceed, hipMemcpy(&normNoAbs, impl->d_array + indx - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
487       *norm = fabs(normNoAbs);
488       break;
489     }
490   }
491   CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array));
492 
493   return CEED_ERROR_SUCCESS;
494 }
495 
496 //------------------------------------------------------------------------------
497 // Take reciprocal of a vector on host
498 //------------------------------------------------------------------------------
499 static int CeedHostReciprocal_Hip(CeedScalar *h_array, CeedInt length) {
500   for (int i = 0; i < length; i++) {
501     if (fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i];
502   }
503   return CEED_ERROR_SUCCESS;
504 }
505 
506 //------------------------------------------------------------------------------
// Take reciprocal of a vector on device (impl in .hip file)
508 //------------------------------------------------------------------------------
509 int CeedDeviceReciprocal_Hip(CeedScalar *d_array, CeedInt length);
510 
511 //------------------------------------------------------------------------------
512 // Take reciprocal of a vector
513 //------------------------------------------------------------------------------
514 static int CeedVectorReciprocal_Hip(CeedVector vec) {
515   Ceed ceed;
516   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
517   CeedVector_Hip *impl;
518   CeedCallBackend(CeedVectorGetData(vec, &impl));
519   CeedSize length;
520   CeedCallBackend(CeedVectorGetLength(vec, &length));
521 
522   // Set value for synced device/host array
523   if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Hip(impl->d_array, length));
524   if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Hip(impl->h_array, length));
525 
526   return CEED_ERROR_SUCCESS;
527 }
528 
529 //------------------------------------------------------------------------------
530 // Compute x = alpha x on the host
531 //------------------------------------------------------------------------------
532 static int CeedHostScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedInt length) {
533   for (int i = 0; i < length; i++) x_array[i] *= alpha;
534   return CEED_ERROR_SUCCESS;
535 }
536 
537 //------------------------------------------------------------------------------
// Compute x = alpha x on device (impl in .hip file)
539 //------------------------------------------------------------------------------
540 int CeedDeviceScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedInt length);
541 
542 //------------------------------------------------------------------------------
543 // Compute x = alpha x
544 //------------------------------------------------------------------------------
545 static int CeedVectorScale_Hip(CeedVector x, CeedScalar alpha) {
546   Ceed ceed;
547   CeedCallBackend(CeedVectorGetCeed(x, &ceed));
548   CeedVector_Hip *x_impl;
549   CeedCallBackend(CeedVectorGetData(x, &x_impl));
550   CeedSize length;
551   CeedCallBackend(CeedVectorGetLength(x, &length));
552 
553   // Set value for synced device/host array
554   if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Hip(x_impl->d_array, alpha, length));
555   if (x_impl->h_array) CeedCallBackend(CeedHostScale_Hip(x_impl->h_array, alpha, length));
556 
557   return CEED_ERROR_SUCCESS;
558 }
559 
560 //------------------------------------------------------------------------------
561 // Compute y = alpha x + y on the host
562 //------------------------------------------------------------------------------
563 static int CeedHostAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedInt length) {
564   for (int i = 0; i < length; i++) y_array[i] += alpha * x_array[i];
565   return CEED_ERROR_SUCCESS;
566 }
567 
568 //------------------------------------------------------------------------------
// Compute y = alpha x + y on device (impl in .hip file)
570 //------------------------------------------------------------------------------
571 int CeedDeviceAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedInt length);
572 
573 //------------------------------------------------------------------------------
574 // Compute y = alpha x + y
575 //------------------------------------------------------------------------------
576 static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) {
577   Ceed ceed;
578   CeedCallBackend(CeedVectorGetCeed(y, &ceed));
579   CeedVector_Hip *y_impl, *x_impl;
580   CeedCallBackend(CeedVectorGetData(y, &y_impl));
581   CeedCallBackend(CeedVectorGetData(x, &x_impl));
582   CeedSize length;
583   CeedCallBackend(CeedVectorGetLength(y, &length));
584 
585   // Set value for synced device/host array
586   if (y_impl->d_array) {
587     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
588     CeedCallBackend(CeedDeviceAXPY_Hip(y_impl->d_array, alpha, x_impl->d_array, length));
589   }
590   if (y_impl->h_array) {
591     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
592     CeedCallBackend(CeedHostAXPY_Hip(y_impl->h_array, alpha, x_impl->h_array, length));
593   }
594 
595   return CEED_ERROR_SUCCESS;
596 }
597 //------------------------------------------------------------------------------
598 // Compute y = alpha x + beta y on the host
599 //------------------------------------------------------------------------------
600 static int CeedHostAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedInt length) {
601   for (int i = 0; i < length; i++) y_array[i] += alpha * x_array[i] + beta * y_array[i];
602   return CEED_ERROR_SUCCESS;
603 }
604 
605 //------------------------------------------------------------------------------
// Compute y = alpha x + beta y on device (impl in .hip file)
607 //------------------------------------------------------------------------------
608 int CeedDeviceAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedInt length);
609 
610 //------------------------------------------------------------------------------
611 // Compute y = alpha x + beta y
612 //------------------------------------------------------------------------------
613 static int CeedVectorAXPBY_Hip(CeedVector y, CeedScalar alpha, CeedScalar beta, CeedVector x) {
614   Ceed ceed;
615   CeedCallBackend(CeedVectorGetCeed(y, &ceed));
616   CeedVector_Hip *y_impl, *x_impl;
617   CeedCallBackend(CeedVectorGetData(y, &y_impl));
618   CeedCallBackend(CeedVectorGetData(x, &x_impl));
619   CeedSize length;
620   CeedCallBackend(CeedVectorGetLength(y, &length));
621 
622   // Set value for synced device/host array
623   if (y_impl->d_array) {
624     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
625     CeedCallBackend(CeedDeviceAXPBY_Hip(y_impl->d_array, alpha, beta, x_impl->d_array, length));
626   }
627   if (y_impl->h_array) {
628     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
629     CeedCallBackend(CeedHostAXPBY_Hip(y_impl->h_array, alpha, beta, x_impl->h_array, length));
630   }
631 
632   return CEED_ERROR_SUCCESS;
633 }
634 
635 //------------------------------------------------------------------------------
636 // Compute the pointwise multiplication w = x .* y on the host
637 //------------------------------------------------------------------------------
638 static int CeedHostPointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedInt length) {
639   for (int i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i];
640   return CEED_ERROR_SUCCESS;
641 }
642 
643 //------------------------------------------------------------------------------
// Compute the pointwise multiplication w = x .* y on device (impl in .hip file)
645 //------------------------------------------------------------------------------
646 int CeedDevicePointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedInt length);
647 
648 //------------------------------------------------------------------------------
649 // Compute the pointwise multiplication w = x .* y
650 //------------------------------------------------------------------------------
651 static int CeedVectorPointwiseMult_Hip(CeedVector w, CeedVector x, CeedVector y) {
652   Ceed ceed;
653   CeedCallBackend(CeedVectorGetCeed(w, &ceed));
654   CeedVector_Hip *w_impl, *x_impl, *y_impl;
655   CeedCallBackend(CeedVectorGetData(w, &w_impl));
656   CeedCallBackend(CeedVectorGetData(x, &x_impl));
657   CeedCallBackend(CeedVectorGetData(y, &y_impl));
658   CeedSize length;
659   CeedCallBackend(CeedVectorGetLength(w, &length));
660 
661   // Set value for synced device/host array
662   if (!w_impl->d_array && !w_impl->h_array) {
663     CeedCallBackend(CeedVectorSetValue(w, 0.0));
664   }
665   if (w_impl->d_array) {
666     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
667     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE));
668     CeedCallBackend(CeedDevicePointwiseMult_Hip(w_impl->d_array, x_impl->d_array, y_impl->d_array, length));
669   }
670   if (w_impl->h_array) {
671     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
672     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST));
673     CeedCallBackend(CeedHostPointwiseMult_Hip(w_impl->h_array, x_impl->h_array, y_impl->h_array, length));
674   }
675 
676   return CEED_ERROR_SUCCESS;
677 }
678 
679 //------------------------------------------------------------------------------
680 // Destroy the vector
681 //------------------------------------------------------------------------------
682 static int CeedVectorDestroy_Hip(const CeedVector vec) {
683   Ceed ceed;
684   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
685   CeedVector_Hip *impl;
686   CeedCallBackend(CeedVectorGetData(vec, &impl));
687 
688   CeedCallHip(ceed, hipFree(impl->d_array_owned));
689   CeedCallBackend(CeedFree(&impl->h_array_owned));
690   CeedCallBackend(CeedFree(&impl));
691 
692   return CEED_ERROR_SUCCESS;
693 }
694 
695 //------------------------------------------------------------------------------
696 // Create a vector of the specified length (does not allocate memory)
697 //------------------------------------------------------------------------------
698 int CeedVectorCreate_Hip(CeedSize n, CeedVector vec) {
699   CeedVector_Hip *impl;
700   Ceed            ceed;
701   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
702 
703   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Hip));
704   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Hip));
705   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Hip));
706   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Hip));
707   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", CeedVectorSetValue_Hip));
708   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Hip));
709   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Hip));
710   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Hip));
711   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Hip));
712   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Norm", CeedVectorNorm_Hip));
713   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Hip));
714   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Scale", CeedVectorScale_Hip));
715   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPY", CeedVectorAXPY_Hip));
716   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPBY", CeedVectorAXPBY_Hip));
717   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Hip));
718   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Hip));
719 
720   CeedCallBackend(CeedCalloc(1, &impl));
721   CeedCallBackend(CeedVectorSetData(vec, impl));
722 
723   return CEED_ERROR_SUCCESS;
724 }
725