xref: /libCEED/backends/hip-ref/ceed-hip-ref-vector.c (revision 7ed177db354372d7a27bd4ae30dbb2e96a56ab87)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <ceed/ceed.h>
18 #include <ceed/backend.h>
19 #include <hip/hip_runtime.h>
20 #include <hipblas.h>
21 #include <math.h>
22 #include <string.h>
23 #include "ceed-hip-ref.h"
24 
25 //------------------------------------------------------------------------------
26 // * Bytes used
27 //------------------------------------------------------------------------------
28 static inline size_t bytes(const CeedVector vec) {
29   int ierr;
30   CeedInt length;
31   ierr = CeedVectorGetLength(vec, &length); CeedChkBackend(ierr);
32   return length * sizeof(CeedScalar);
33 }
34 
35 //------------------------------------------------------------------------------
36 // Sync host to device
37 //------------------------------------------------------------------------------
38 static inline int CeedVectorSyncH2D_Hip(const CeedVector vec) {
39   int ierr;
40   Ceed ceed;
41   ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
42   CeedVector_Hip *impl;
43   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
44 
45   if (!impl->h_array)
46     // LCOV_EXCL_START
47     return CeedError(ceed, CEED_ERROR_BACKEND,
48                      "No valid host data to sync to device");
49   // LCOV_EXCL_STOP
50 
51   if (impl->d_array_borrowed) {
52     impl->d_array = impl->d_array_borrowed;
53   } else if (impl->d_array_owned) {
54     impl->d_array = impl->d_array_owned;
55   } else {
56     ierr = hipMalloc((void **)&impl->d_array_owned, bytes(vec));
57     CeedChk_Hip(ceed, ierr);
58     impl->d_array = impl->d_array_owned;
59   }
60 
61   ierr = hipMemcpy(impl->d_array, impl->h_array, bytes(vec),
62                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
63 
64   return CEED_ERROR_SUCCESS;
65 }
66 
67 //------------------------------------------------------------------------------
68 // Sync device to host
69 //------------------------------------------------------------------------------
70 static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) {
71   int ierr;
72   Ceed ceed;
73   ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
74   CeedVector_Hip *impl;
75   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
76 
77   if (!impl->d_array)
78     // LCOV_EXCL_START
79     return CeedError(ceed, CEED_ERROR_BACKEND,
80                      "No valid device data to sync to host");
81   // LCOV_EXCL_STOP
82 
83   if (impl->h_array_borrowed) {
84     impl->h_array = impl->h_array_borrowed;
85   } else if (impl->h_array_owned) {
86     impl->h_array = impl->h_array_owned;
87   } else {
88     CeedInt length;
89     ierr = CeedVectorGetLength(vec, &length); CeedChkBackend(ierr);
90     ierr = CeedCalloc(length, &impl->h_array_owned); CeedChkBackend(ierr);
91     impl->h_array = impl->h_array_owned;
92   }
93 
94   ierr = hipMemcpy(impl->h_array, impl->d_array, bytes(vec),
95                    hipMemcpyDeviceToHost); CeedChk_Hip(ceed, ierr);
96 
97   return CEED_ERROR_SUCCESS;
98 }
99 
100 //------------------------------------------------------------------------------
101 // Sync arrays
102 //------------------------------------------------------------------------------
103 static inline int CeedVectorSync_Hip(const CeedVector vec,
104                                      CeedMemType mem_type) {
105   switch (mem_type) {
106   case CEED_MEM_HOST: return CeedVectorSyncD2H_Hip(vec);
107   case CEED_MEM_DEVICE: return CeedVectorSyncH2D_Hip(vec);
108   }
109   return CEED_ERROR_UNSUPPORTED;
110 }
111 
112 //------------------------------------------------------------------------------
113 // Set all pointers as invalid
114 //------------------------------------------------------------------------------
115 static inline int CeedVectorSetAllInvalid_Hip(const CeedVector vec) {
116   int ierr;
117   CeedVector_Hip *impl;
118   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
119 
120   impl->h_array = NULL;
121   impl->d_array = NULL;
122 
123   return CEED_ERROR_SUCCESS;
124 }
125 
126 //------------------------------------------------------------------------------
127 // Check if CeedVector has any valid pointers
128 //------------------------------------------------------------------------------
129 static inline int CeedVectorHasValidArray_Hip(const CeedVector vec,
130     bool *has_valid_array) {
131   int ierr;
132   CeedVector_Hip *impl;
133   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
134 
135   *has_valid_array = !!impl->h_array || !!impl->d_array;
136 
137   return CEED_ERROR_SUCCESS;
138 }
139 
140 //------------------------------------------------------------------------------
141 // Check if has any array of given type
142 //------------------------------------------------------------------------------
143 static inline int CeedVectorHasArrayOfType_Hip(const CeedVector vec,
144     CeedMemType mem_type, bool *has_array_of_type) {
145   int ierr;
146   CeedVector_Hip *impl;
147   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
148 
149   switch (mem_type) {
150   case CEED_MEM_HOST:
151     *has_array_of_type = !!impl->h_array_borrowed || !!impl->h_array_owned;
152     break;
153   case CEED_MEM_DEVICE:
154     *has_array_of_type = !!impl->d_array_borrowed || !!impl->d_array_owned;
155     break;
156   }
157 
158   return CEED_ERROR_SUCCESS;
159 }
160 
161 //------------------------------------------------------------------------------
162 // Check if has borrowed array of given type
163 //------------------------------------------------------------------------------
164 static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec,
165     CeedMemType mem_type, bool *has_borrowed_array_of_type) {
166   int ierr;
167   CeedVector_Hip *impl;
168   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
169 
170   switch (mem_type) {
171   case CEED_MEM_HOST:
172     *has_borrowed_array_of_type = !!impl->h_array_borrowed;
173     break;
174   case CEED_MEM_DEVICE:
175     *has_borrowed_array_of_type = !!impl->d_array_borrowed;
176     break;
177   }
178 
179   return CEED_ERROR_SUCCESS;
180 }
181 
182 //------------------------------------------------------------------------------
183 // Sync array of given type
184 //------------------------------------------------------------------------------
185 static inline int CeedVectorNeedSync_Hip(const CeedVector vec,
186     CeedMemType mem_type, bool *need_sync) {
187   int ierr;
188   CeedVector_Hip *impl;
189   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
190 
191   bool has_valid_array = false;
192   ierr = CeedVectorHasValidArray(vec, &has_valid_array); CeedChkBackend(ierr);
193   switch (mem_type) {
194   case CEED_MEM_HOST:
195     *need_sync = has_valid_array && !impl->h_array;
196     break;
197   case CEED_MEM_DEVICE:
198     *need_sync = has_valid_array && !impl->d_array;
199     break;
200   }
201 
202   return CEED_ERROR_SUCCESS;
203 }
204 
205 //------------------------------------------------------------------------------
206 // Set array from host
207 //------------------------------------------------------------------------------
208 static int CeedVectorSetArrayHost_Hip(const CeedVector vec,
209                                       const CeedCopyMode copy_mode, CeedScalar *array) {
210   int ierr;
211   CeedVector_Hip *impl;
212   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
213 
214   switch (copy_mode) {
215   case CEED_COPY_VALUES: {
216     CeedInt length;
217     if (!impl->h_array_owned) {
218       ierr = CeedVectorGetLength(vec, &length); CeedChkBackend(ierr);
219       ierr = CeedMalloc(length, &impl->h_array_owned); CeedChkBackend(ierr);
220     }
221     impl->h_array_borrowed = NULL;
222     impl->h_array = impl->h_array_owned;
223     if (array)
224       memcpy(impl->h_array, array, bytes(vec));
225   } break;
226   case CEED_OWN_POINTER:
227     ierr = CeedFree(&impl->h_array_owned); CeedChkBackend(ierr);
228     impl->h_array_owned = array;
229     impl->h_array_borrowed = NULL;
230     impl->h_array = array;
231     break;
232   case CEED_USE_POINTER:
233     ierr = CeedFree(&impl->h_array_owned); CeedChkBackend(ierr);
234     impl->h_array_borrowed = array;
235     impl->h_array = array;
236     break;
237   }
238 
239   return CEED_ERROR_SUCCESS;
240 }
241 
242 //------------------------------------------------------------------------------
243 // Set array from device
244 //------------------------------------------------------------------------------
245 static int CeedVectorSetArrayDevice_Hip(const CeedVector vec,
246                                         const CeedCopyMode copy_mode, CeedScalar *array) {
247   int ierr;
248   Ceed ceed;
249   ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
250   CeedVector_Hip *impl;
251   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
252 
253   switch (copy_mode) {
254   case CEED_COPY_VALUES:
255     if (!impl->d_array_owned) {
256       ierr = hipMalloc((void **)&impl->d_array_owned, bytes(vec));
257       CeedChk_Hip(ceed, ierr);
258     }
259     impl->d_array_borrowed = NULL;
260     impl->d_array = impl->d_array_owned;
261     if (array) {
262       ierr = hipMemcpy(impl->d_array, array, bytes(vec),
263                        hipMemcpyDeviceToDevice); CeedChk_Hip(ceed, ierr);
264     }
265     break;
266   case CEED_OWN_POINTER:
267     ierr = hipFree(impl->d_array_owned); CeedChk_Hip(ceed, ierr);
268     impl->d_array_owned = array;
269     impl->d_array_borrowed = NULL;
270     impl->d_array = array;
271     break;
272   case CEED_USE_POINTER:
273     ierr = hipFree(impl->d_array_owned); CeedChk_Hip(ceed, ierr);
274     impl->d_array_owned = NULL;
275     impl->d_array_borrowed = array;
276     impl->d_array = array;
277     break;
278   }
279 
280   return CEED_ERROR_SUCCESS;
281 }
282 
283 //------------------------------------------------------------------------------
284 // Set the array used by a vector,
285 //   freeing any previously allocated array if applicable
286 //------------------------------------------------------------------------------
287 static int CeedVectorSetArray_Hip(const CeedVector vec,
288                                   const CeedMemType mem_type,
289                                   const CeedCopyMode copy_mode, CeedScalar *array) {
290   int ierr;
291   Ceed ceed;
292   ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
293   CeedVector_Hip *impl;
294   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
295 
296   ierr = CeedVectorSetAllInvalid_Hip(vec); CeedChkBackend(ierr);
297   switch (mem_type) {
298   case CEED_MEM_HOST:
299     return CeedVectorSetArrayHost_Hip(vec, copy_mode, array);
300   case CEED_MEM_DEVICE:
301     return CeedVectorSetArrayDevice_Hip(vec, copy_mode, array);
302   }
303 
304   return CEED_ERROR_UNSUPPORTED;
305 }
306 
307 //------------------------------------------------------------------------------
308 // Set host array to value
309 //------------------------------------------------------------------------------
310 static int CeedHostSetValue_Hip(CeedScalar *h_array, CeedInt length,
311                                 CeedScalar val) {
312   for (int i = 0; i < length; i++)
313     h_array[i] = val;
314   return CEED_ERROR_SUCCESS;
315 }
316 
317 //------------------------------------------------------------------------------
318 // Set device array to value (impl in .hip file)
319 //------------------------------------------------------------------------------
320 int CeedDeviceSetValue_Hip(CeedScalar *d_array, CeedInt length, CeedScalar val);
321 
322 //------------------------------------------------------------------------------
323 // Set a vector to a value,
324 //------------------------------------------------------------------------------
325 static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) {
326   int ierr;
327   Ceed ceed;
328   ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
329   CeedVector_Hip *impl;
330   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
331   CeedInt length;
332   ierr = CeedVectorGetLength(vec, &length); CeedChkBackend(ierr);
333 
334   // Set value for synced device/host array
335   if (!impl->d_array && !impl->h_array) {
336     if (impl->d_array_borrowed) {
337       impl->d_array = impl->d_array_borrowed;
338     } else if (impl->h_array_borrowed) {
339       impl->h_array = impl->h_array_borrowed;
340     } else if (impl->d_array_owned) {
341       impl->d_array = impl->d_array_owned;
342     } else if (impl->h_array_owned) {
343       impl->h_array = impl->h_array_owned;
344     } else {
345       ierr = CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL);
346       CeedChkBackend(ierr);
347     }
348   }
349   if (impl->d_array) {
350     ierr = CeedDeviceSetValue_Hip(impl->d_array, length, val); CeedChkBackend(ierr);
351   }
352   if (impl->h_array) {
353     ierr = CeedHostSetValue_Hip(impl->h_array, length, val); CeedChkBackend(ierr);
354   }
355 
356   return CEED_ERROR_SUCCESS;
357 }
358 
359 //------------------------------------------------------------------------------
360 // Vector Take Array
361 //------------------------------------------------------------------------------
362 static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type,
363                                    CeedScalar **array) {
364   int ierr;
365   Ceed ceed;
366   ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
367   CeedVector_Hip *impl;
368   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
369 
370   // Sync array to requested mem_type
371   bool need_sync = false;
372   ierr = CeedVectorNeedSync_Hip(vec, mem_type, &need_sync); CeedChkBackend(ierr);
373   if (need_sync) {
374     ierr = CeedVectorSync_Hip(vec, mem_type); CeedChkBackend(ierr);
375   }
376 
377   // Update pointer
378   switch (mem_type) {
379   case CEED_MEM_HOST:
380     (*array) = impl->h_array_borrowed;
381     impl->h_array_borrowed = NULL;
382     impl->h_array = NULL;
383     break;
384   case CEED_MEM_DEVICE:
385     (*array) = impl->d_array_borrowed;
386     impl->d_array_borrowed = NULL;
387     impl->d_array = NULL;
388     break;
389   }
390 
391   return CEED_ERROR_SUCCESS;
392 }
393 
394 //------------------------------------------------------------------------------
// Core logic for array synchronization for GetArray.
396 //   If a different memory type is most up to date, this will perform a copy
397 //------------------------------------------------------------------------------
398 static int CeedVectorGetArrayCore_Hip(const CeedVector vec,
399                                       const CeedMemType mem_type, CeedScalar **array) {
400   int ierr;
401   Ceed ceed;
402   ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
403   CeedVector_Hip *impl;
404   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
405 
406   bool need_sync = false;
407   ierr = CeedVectorNeedSync_Hip(vec, mem_type, &need_sync); CeedChkBackend(ierr);
408   CeedChkBackend(ierr);
409   if (need_sync) {
410     // Sync array to requested mem_type
411     ierr = CeedVectorSync_Hip(vec, mem_type); CeedChkBackend(ierr);
412   }
413 
414   // Update pointer
415   switch (mem_type) {
416   case CEED_MEM_HOST:
417     *array = impl->h_array;
418     break;
419   case CEED_MEM_DEVICE:
420     *array = impl->d_array;
421     break;
422   }
423 
424   return CEED_ERROR_SUCCESS;
425 }
426 
427 //------------------------------------------------------------------------------
428 // Get read-only access to a vector via the specified mem_type
429 //------------------------------------------------------------------------------
430 static int CeedVectorGetArrayRead_Hip(const CeedVector vec,
431                                       const CeedMemType mem_type, const CeedScalar **array) {
432   return CeedVectorGetArrayCore_Hip(vec, mem_type, (CeedScalar **)array);
433 }
434 
435 //------------------------------------------------------------------------------
436 // Get read/write access to a vector via the specified mem_type
437 //------------------------------------------------------------------------------
438 static int CeedVectorGetArray_Hip(const CeedVector vec,
439                                   const CeedMemType mem_type,
440                                   CeedScalar **array) {
441   int ierr;
442   CeedVector_Hip *impl;
443   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
444 
445   ierr = CeedVectorGetArrayCore_Hip(vec, mem_type, array); CeedChkBackend(ierr);
446 
447   ierr = CeedVectorSetAllInvalid_Hip(vec); CeedChkBackend(ierr);
448   switch (mem_type) {
449   case CEED_MEM_HOST:
450     impl->h_array = *array;
451     break;
452   case CEED_MEM_DEVICE:
453     impl->d_array = *array;
454     break;
455   }
456 
457   return CEED_ERROR_SUCCESS;
458 }
459 
460 //------------------------------------------------------------------------------
461 // Get write access to a vector via the specified mem_type
462 //------------------------------------------------------------------------------
463 static int CeedVectorGetArrayWrite_Hip(const CeedVector vec,
464                                        const CeedMemType mem_type, CeedScalar **array) {
465   int ierr;
466   CeedVector_Hip *impl;
467   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
468 
469   bool has_array_of_type = true;
470   ierr = CeedVectorHasArrayOfType_Hip(vec, mem_type, &has_array_of_type);
471   CeedChkBackend(ierr);
472   if (!has_array_of_type) {
473     // Allocate if array is not yet allocated
474     ierr = CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL);
475     CeedChkBackend(ierr);
476   } else {
477     // Select dirty array
478     switch (mem_type) {
479     case CEED_MEM_HOST:
480       if (impl->h_array_borrowed)
481         impl->h_array = impl->h_array_borrowed;
482       else
483         impl->h_array = impl->h_array_owned;
484       break;
485     case CEED_MEM_DEVICE:
486       if (impl->d_array_borrowed)
487         impl->d_array = impl->d_array_borrowed;
488       else
489         impl->d_array = impl->d_array_owned;
490     }
491   }
492 
493   return CeedVectorGetArray_Hip(vec, mem_type, array);
494 }
495 
496 //------------------------------------------------------------------------------
497 // Get the norm of a CeedVector
498 //------------------------------------------------------------------------------
499 static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type,
500                               CeedScalar *norm) {
501   int ierr;
502   Ceed ceed;
503   ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
504   CeedVector_Hip *impl;
505   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
506   CeedInt length;
507   ierr = CeedVectorGetLength(vec, &length); CeedChkBackend(ierr);
508   hipblasHandle_t handle;
509   ierr = CeedHipGetHipblasHandle(ceed, &handle); CeedChkBackend(ierr);
510 
511   // Compute norm
512   const CeedScalar *d_array;
513   ierr = CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array);
514   CeedChkBackend(ierr);
515   switch (type) {
516   case CEED_NORM_1: {
517     if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
518       ierr = hipblasSasum(handle, length, (float *) d_array, 1, (float *) norm);
519     } else {
520       ierr = hipblasDasum(handle, length, (double *) d_array, 1, (double *) norm);
521     }
522     CeedChk_Hipblas(ceed, ierr);
523     break;
524   }
525   case CEED_NORM_2: {
526     if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
527       ierr = hipblasSnrm2(handle, length, (float *) d_array, 1, (float *) norm);
528     } else {
529       ierr = hipblasDnrm2(handle, length, (double *) d_array, 1, (double *) norm);
530     }
531     CeedChk_Hipblas(ceed, ierr);
532     break;
533   }
534   case CEED_NORM_MAX: {
535     CeedInt indx;
536     if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
537       ierr = hipblasIsamax(handle, length, (float *) d_array, 1, &indx);
538     } else {
539       ierr = hipblasIdamax(handle, length, (double *) d_array, 1, &indx);
540     }
541     CeedChk_Hipblas(ceed, ierr);
542     CeedScalar normNoAbs;
543     ierr = hipMemcpy(&normNoAbs, impl->d_array+indx-1, sizeof(CeedScalar),
544                      hipMemcpyDeviceToHost); CeedChk_Hip(ceed, ierr);
545     *norm = fabs(normNoAbs);
546     break;
547   }
548   }
549   ierr = CeedVectorRestoreArrayRead(vec, &d_array); CeedChkBackend(ierr);
550 
551   return CEED_ERROR_SUCCESS;
552 }
553 
554 //------------------------------------------------------------------------------
555 // Take reciprocal of a vector on host
556 //------------------------------------------------------------------------------
557 static int CeedHostReciprocal_Hip(CeedScalar *h_array, CeedInt length) {
558   for (int i = 0; i < length; i++)
559     if (fabs(h_array[i]) > CEED_EPSILON)
560       h_array[i] = 1./h_array[i];
561   return CEED_ERROR_SUCCESS;
562 }
563 
564 //------------------------------------------------------------------------------
// Take reciprocal of a vector on device (impl in .hip file)
566 //------------------------------------------------------------------------------
567 int CeedDeviceReciprocal_Hip(CeedScalar *d_array, CeedInt length);
568 
569 //------------------------------------------------------------------------------
570 // Take reciprocal of a vector
571 //------------------------------------------------------------------------------
572 static int CeedVectorReciprocal_Hip(CeedVector vec) {
573   int ierr;
574   Ceed ceed;
575   ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
576   CeedVector_Hip *impl;
577   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
578   CeedInt length;
579   ierr = CeedVectorGetLength(vec, &length); CeedChkBackend(ierr);
580 
581   // Set value for synced device/host array
582   if (impl->d_array) {
583     ierr = CeedDeviceReciprocal_Hip(impl->d_array, length); CeedChkBackend(ierr);
584   }
585   if (impl->h_array) {
586     ierr = CeedHostReciprocal_Hip(impl->h_array, length); CeedChkBackend(ierr);
587   }
588 
589   return CEED_ERROR_SUCCESS;
590 }
591 
592 //------------------------------------------------------------------------------
593 // Compute x = alpha x on the host
594 //------------------------------------------------------------------------------
595 static int CeedHostScale_Hip(CeedScalar *x_array, CeedScalar alpha,
596                              CeedInt length) {
597   for (int i = 0; i < length; i++)
598     x_array[i] *= alpha;
599   return CEED_ERROR_SUCCESS;
600 }
601 
602 //------------------------------------------------------------------------------
// Compute x = alpha x on device (impl in .hip file)
604 //------------------------------------------------------------------------------
605 int CeedDeviceScale_Hip(CeedScalar *x_array, CeedScalar alpha,
606                         CeedInt length);
607 
608 //------------------------------------------------------------------------------
609 // Compute x = alpha x
610 //------------------------------------------------------------------------------
611 static int CeedVectorScale_Hip(CeedVector x, CeedScalar alpha) {
612   int ierr;
613   Ceed ceed;
614   ierr = CeedVectorGetCeed(x, &ceed); CeedChkBackend(ierr);
615   CeedVector_Hip *x_impl;
616   ierr = CeedVectorGetData(x, &x_impl); CeedChkBackend(ierr);
617   CeedInt length;
618   ierr = CeedVectorGetLength(x, &length); CeedChkBackend(ierr);
619 
620   // Set value for synced device/host array
621   if (x_impl->d_array) {
622     ierr = CeedDeviceScale_Hip(x_impl->d_array, alpha, length);
623     CeedChkBackend(ierr);
624   }
625   if (x_impl->h_array) {
626     ierr = CeedHostScale_Hip(x_impl->h_array, alpha, length); CeedChkBackend(ierr);
627   }
628 
629   return CEED_ERROR_SUCCESS;
630 }
631 
632 //------------------------------------------------------------------------------
633 // Compute y = alpha x + y on the host
634 //------------------------------------------------------------------------------
635 static int CeedHostAXPY_Hip(CeedScalar *y_array, CeedScalar alpha,
636                             CeedScalar *x_array, CeedInt length) {
637   for (int i = 0; i < length; i++)
638     y_array[i] += alpha * x_array[i];
639   return CEED_ERROR_SUCCESS;
640 }
641 
642 //------------------------------------------------------------------------------
// Compute y = alpha x + y on device (impl in .hip file)
644 //------------------------------------------------------------------------------
645 int CeedDeviceAXPY_Hip(CeedScalar *y_array, CeedScalar alpha,
646                        CeedScalar *x_array, CeedInt length);
647 
648 //------------------------------------------------------------------------------
649 // Compute y = alpha x + y
650 //------------------------------------------------------------------------------
651 static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) {
652   int ierr;
653   Ceed ceed;
654   ierr = CeedVectorGetCeed(y, &ceed); CeedChkBackend(ierr);
655   CeedVector_Hip *y_impl, *x_impl;
656   ierr = CeedVectorGetData(y, &y_impl); CeedChkBackend(ierr);
657   ierr = CeedVectorGetData(x, &x_impl); CeedChkBackend(ierr);
658   CeedInt length;
659   ierr = CeedVectorGetLength(y, &length); CeedChkBackend(ierr);
660 
661   // Set value for synced device/host array
662   if (y_impl->d_array) {
663     ierr = CeedVectorSyncArray(x, CEED_MEM_DEVICE); CeedChkBackend(ierr);
664     ierr = CeedDeviceAXPY_Hip(y_impl->d_array, alpha, x_impl->d_array, length);
665     CeedChkBackend(ierr);
666   }
667   if (y_impl->h_array) {
668     ierr = CeedVectorSyncArray(x, CEED_MEM_HOST); CeedChkBackend(ierr);
669     ierr = CeedHostAXPY_Hip(y_impl->h_array, alpha, x_impl->h_array, length);
670     CeedChkBackend(ierr);
671   }
672 
673   return CEED_ERROR_SUCCESS;
674 }
675 
676 //------------------------------------------------------------------------------
677 // Compute the pointwise multiplication w = x .* y on the host
678 //------------------------------------------------------------------------------
679 static int CeedHostPointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array,
680                                      CeedScalar *y_array, CeedInt length) {
681   for (int i = 0; i < length; i++)
682     w_array[i] = x_array[i] * y_array[i];
683   return CEED_ERROR_SUCCESS;
684 }
685 
686 //------------------------------------------------------------------------------
// Compute the pointwise multiplication w = x .* y on device (impl in .hip file)
688 //------------------------------------------------------------------------------
689 int CeedDevicePointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array,
690                                 CeedScalar *y_array, CeedInt length);
691 
692 //------------------------------------------------------------------------------
693 // Compute the pointwise multiplication w = x .* y
694 //------------------------------------------------------------------------------
695 static int CeedVectorPointwiseMult_Hip(CeedVector w, CeedVector x,
696                                        CeedVector y) {
697   int ierr;
698   Ceed ceed;
699   ierr = CeedVectorGetCeed(w, &ceed); CeedChkBackend(ierr);
700   CeedVector_Hip *w_impl, *x_impl, *y_impl;
701   ierr = CeedVectorGetData(w, &w_impl); CeedChkBackend(ierr);
702   ierr = CeedVectorGetData(x, &x_impl); CeedChkBackend(ierr);
703   ierr = CeedVectorGetData(y, &y_impl); CeedChkBackend(ierr);
704   CeedInt length;
705   ierr = CeedVectorGetLength(w, &length); CeedChkBackend(ierr);
706 
707   // Set value for synced device/host array
708   if (!w_impl->d_array && !w_impl->h_array) {
709     ierr = CeedVectorSetValue(w, 0.0); CeedChkBackend(ierr);
710   }
711   if (w_impl->d_array) {
712     ierr = CeedVectorSyncArray(x, CEED_MEM_DEVICE); CeedChkBackend(ierr);
713     ierr = CeedVectorSyncArray(y, CEED_MEM_DEVICE); CeedChkBackend(ierr);
714     ierr = CeedDevicePointwiseMult_Hip(w_impl->d_array, x_impl->d_array,
715                                        y_impl->d_array, length);
716     CeedChkBackend(ierr);
717   }
718   if (w_impl->h_array) {
719     ierr = CeedVectorSyncArray(x, CEED_MEM_HOST); CeedChkBackend(ierr);
720     ierr = CeedVectorSyncArray(y, CEED_MEM_HOST); CeedChkBackend(ierr);
721     ierr = CeedHostPointwiseMult_Hip(w_impl->h_array, x_impl->h_array,
722                                      y_impl->h_array, length);
723     CeedChkBackend(ierr);
724   }
725 
726   return CEED_ERROR_SUCCESS;
727 }
728 
729 //------------------------------------------------------------------------------
730 // Destroy the vector
731 //------------------------------------------------------------------------------
732 static int CeedVectorDestroy_Hip(const CeedVector vec) {
733   int ierr;
734   Ceed ceed;
735   ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
736   CeedVector_Hip *impl;
737   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
738 
739   ierr = hipFree(impl->d_array_owned); CeedChk_Hip(ceed, ierr);
740   ierr = CeedFree(&impl->h_array_owned); CeedChkBackend(ierr);
741   ierr = CeedFree(&impl); CeedChkBackend(ierr);
742 
743   return CEED_ERROR_SUCCESS;
744 }
745 
746 //------------------------------------------------------------------------------
747 // Create a vector of the specified length (does not allocate memory)
748 //------------------------------------------------------------------------------
749 int CeedVectorCreate_Hip(CeedInt n, CeedVector vec) {
750   CeedVector_Hip *impl;
751   int ierr;
752   Ceed ceed;
753   ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
754 
755   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray",
756                                 CeedVectorHasValidArray_Hip); CeedChkBackend(ierr);
757   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType",
758                                 CeedVectorHasBorrowedArrayOfType_Hip);
759   CeedChkBackend(ierr);
760   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SetArray",
761                                 CeedVectorSetArray_Hip); CeedChkBackend(ierr);
762   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray",
763                                 CeedVectorTakeArray_Hip); CeedChkBackend(ierr);
764   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SetValue",
765                                 (int (*)())(CeedVectorSetValue_Hip)); CeedChkBackend(ierr);
766   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArray",
767                                 CeedVectorGetArray_Hip); CeedChkBackend(ierr);
768   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead",
769                                 CeedVectorGetArrayRead_Hip); CeedChkBackend(ierr);
770   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite",
771                                 CeedVectorGetArrayWrite_Hip); CeedChkBackend(ierr);
772   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "Norm",
773                                 CeedVectorNorm_Hip); CeedChkBackend(ierr);
774   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal",
775                                 CeedVectorReciprocal_Hip); CeedChkBackend(ierr);
776   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "Scale",
777                                 (int (*)())(CeedVectorScale_Hip)); CeedChkBackend(ierr);
778   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "AXPY",
779                                 (int (*)())(CeedVectorAXPY_Hip)); CeedChkBackend(ierr);
780   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult",
781                                 CeedVectorPointwiseMult_Hip); CeedChkBackend(ierr);
782   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "Destroy",
783                                 CeedVectorDestroy_Hip); CeedChkBackend(ierr);
784 
785   ierr = CeedCalloc(1, &impl); CeedChkBackend(ierr);
786   ierr = CeedVectorSetData(vec, impl); CeedChkBackend(ierr);
787 
788   return CEED_ERROR_SUCCESS;
789 }
790