xref: /libCEED/backends/hip-ref/ceed-hip-ref-vector.c (revision dc64899e51135d10dd401f305b8beed8356c525d)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/ceed.h>
9 #include <ceed/backend.h>
10 #include <hip/hip_runtime.h>
11 #include <math.h>
12 #include <string.h>
13 #include "ceed-hip-ref.h"
14 
15 
16 //------------------------------------------------------------------------------
17 // Check if host/device sync is needed
18 //------------------------------------------------------------------------------
19 static inline int CeedVectorNeedSync_Hip(const CeedVector vec,
20     CeedMemType mem_type, bool *need_sync) {
21   int ierr;
22   CeedVector_Hip *impl;
23   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
24 
25   bool has_valid_array = false;
26   ierr = CeedVectorHasValidArray(vec, &has_valid_array); CeedChkBackend(ierr);
27   switch (mem_type) {
28   case CEED_MEM_HOST:
29     *need_sync = has_valid_array && !impl->h_array;
30     break;
31   case CEED_MEM_DEVICE:
32     *need_sync = has_valid_array && !impl->d_array;
33     break;
34   }
35 
36   return CEED_ERROR_SUCCESS;
37 }
38 
39 //------------------------------------------------------------------------------
40 // Sync host to device
41 //------------------------------------------------------------------------------
42 static inline int CeedVectorSyncH2D_Hip(const CeedVector vec) {
43   int ierr;
44   Ceed ceed;
45   ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
46   CeedVector_Hip *impl;
47   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
48 
49   CeedSize length;
50   ierr = CeedVectorGetLength(vec, &length); CeedChkBackend(ierr);
51   size_t bytes = length * sizeof(CeedScalar);
52 
53   if (!impl->h_array)
54     // LCOV_EXCL_START
55     return CeedError(ceed, CEED_ERROR_BACKEND,
56                      "No valid host data to sync to device");
57   // LCOV_EXCL_STOP
58 
59   if (impl->d_array_borrowed) {
60     impl->d_array = impl->d_array_borrowed;
61   } else if (impl->d_array_owned) {
62     impl->d_array = impl->d_array_owned;
63   } else {
64     ierr = hipMalloc((void **)&impl->d_array_owned, bytes);
65     CeedChk_Hip(ceed, ierr);
66     impl->d_array = impl->d_array_owned;
67   }
68 
69   ierr = hipMemcpy(impl->d_array, impl->h_array, bytes,
70                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
71 
72   return CEED_ERROR_SUCCESS;
73 }
74 
75 //------------------------------------------------------------------------------
76 // Sync device to host
77 //------------------------------------------------------------------------------
78 static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) {
79   int ierr;
80   Ceed ceed;
81   ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
82   CeedVector_Hip *impl;
83   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
84 
85   if (!impl->d_array)
86     // LCOV_EXCL_START
87     return CeedError(ceed, CEED_ERROR_BACKEND,
88                      "No valid device data to sync to host");
89   // LCOV_EXCL_STOP
90 
91   if (impl->h_array_borrowed) {
92     impl->h_array = impl->h_array_borrowed;
93   } else if (impl->h_array_owned) {
94     impl->h_array = impl->h_array_owned;
95   } else {
96     CeedSize length;
97     ierr = CeedVectorGetLength(vec, &length); CeedChkBackend(ierr);
98     ierr = CeedCalloc(length, &impl->h_array_owned); CeedChkBackend(ierr);
99     impl->h_array = impl->h_array_owned;
100   }
101 
102   CeedSize length;
103   ierr = CeedVectorGetLength(vec, &length); CeedChkBackend(ierr);
104   size_t bytes = length * sizeof(CeedScalar);
105   ierr = hipMemcpy(impl->h_array, impl->d_array, bytes,
106                    hipMemcpyDeviceToHost); CeedChk_Hip(ceed, ierr);
107 
108   return CEED_ERROR_SUCCESS;
109 }
110 
111 //------------------------------------------------------------------------------
112 // Sync arrays
113 //------------------------------------------------------------------------------
114 static int CeedVectorSyncArray_Hip(const CeedVector vec,
115                                    CeedMemType mem_type) {
116   int ierr;
117   // Check whether device/host sync is needed
118   bool need_sync = false;
119   ierr = CeedVectorNeedSync_Hip(vec, mem_type, &need_sync);
120   CeedChkBackend(ierr);
121   if (!need_sync)
122     return CEED_ERROR_SUCCESS;
123 
124   switch (mem_type) {
125   case CEED_MEM_HOST: return CeedVectorSyncD2H_Hip(vec);
126   case CEED_MEM_DEVICE: return CeedVectorSyncH2D_Hip(vec);
127   }
128   return CEED_ERROR_UNSUPPORTED;
129 }
130 
131 //------------------------------------------------------------------------------
132 // Set all pointers as invalid
133 //------------------------------------------------------------------------------
134 static inline int CeedVectorSetAllInvalid_Hip(const CeedVector vec) {
135   int ierr;
136   CeedVector_Hip *impl;
137   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
138 
139   impl->h_array = NULL;
140   impl->d_array = NULL;
141 
142   return CEED_ERROR_SUCCESS;
143 }
144 
145 //------------------------------------------------------------------------------
146 // Check if CeedVector has any valid pointers
147 //------------------------------------------------------------------------------
148 static inline int CeedVectorHasValidArray_Hip(const CeedVector vec,
149     bool *has_valid_array) {
150   int ierr;
151   CeedVector_Hip *impl;
152   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
153 
154   *has_valid_array = !!impl->h_array || !!impl->d_array;
155 
156   return CEED_ERROR_SUCCESS;
157 }
158 
159 //------------------------------------------------------------------------------
160 // Check if has any array of given type
161 //------------------------------------------------------------------------------
162 static inline int CeedVectorHasArrayOfType_Hip(const CeedVector vec,
163     CeedMemType mem_type, bool *has_array_of_type) {
164   int ierr;
165   CeedVector_Hip *impl;
166   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
167 
168   switch (mem_type) {
169   case CEED_MEM_HOST:
170     *has_array_of_type = !!impl->h_array_borrowed || !!impl->h_array_owned;
171     break;
172   case CEED_MEM_DEVICE:
173     *has_array_of_type = !!impl->d_array_borrowed || !!impl->d_array_owned;
174     break;
175   }
176 
177   return CEED_ERROR_SUCCESS;
178 }
179 
180 //------------------------------------------------------------------------------
181 // Check if has borrowed array of given type
182 //------------------------------------------------------------------------------
183 static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec,
184     CeedMemType mem_type, bool *has_borrowed_array_of_type) {
185   int ierr;
186   CeedVector_Hip *impl;
187   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
188 
189   switch (mem_type) {
190   case CEED_MEM_HOST:
191     *has_borrowed_array_of_type = !!impl->h_array_borrowed;
192     break;
193   case CEED_MEM_DEVICE:
194     *has_borrowed_array_of_type = !!impl->d_array_borrowed;
195     break;
196   }
197 
198   return CEED_ERROR_SUCCESS;
199 }
200 
201 //------------------------------------------------------------------------------
202 // Set array from host
203 //------------------------------------------------------------------------------
204 static int CeedVectorSetArrayHost_Hip(const CeedVector vec,
205                                       const CeedCopyMode copy_mode, CeedScalar *array) {
206   int ierr;
207   CeedVector_Hip *impl;
208   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
209 
210   switch (copy_mode) {
211   case CEED_COPY_VALUES: {
212     CeedSize length;
213     if (!impl->h_array_owned) {
214       ierr = CeedVectorGetLength(vec, &length); CeedChkBackend(ierr);
215       ierr = CeedMalloc(length, &impl->h_array_owned); CeedChkBackend(ierr);
216     }
217     impl->h_array_borrowed = NULL;
218     impl->h_array = impl->h_array_owned;
219     if (array) {
220       CeedSize length;
221       ierr = CeedVectorGetLength(vec, &length); CeedChkBackend(ierr);
222       size_t bytes = length * sizeof(CeedScalar);
223       memcpy(impl->h_array, array, bytes);
224     }
225   } break;
226   case CEED_OWN_POINTER:
227     ierr = CeedFree(&impl->h_array_owned); CeedChkBackend(ierr);
228     impl->h_array_owned = array;
229     impl->h_array_borrowed = NULL;
230     impl->h_array = array;
231     break;
232   case CEED_USE_POINTER:
233     ierr = CeedFree(&impl->h_array_owned); CeedChkBackend(ierr);
234     impl->h_array_borrowed = array;
235     impl->h_array = array;
236     break;
237   }
238 
239   return CEED_ERROR_SUCCESS;
240 }
241 
242 //------------------------------------------------------------------------------
243 // Set array from device
244 //------------------------------------------------------------------------------
245 static int CeedVectorSetArrayDevice_Hip(const CeedVector vec,
246                                         const CeedCopyMode copy_mode, CeedScalar *array) {
247   int ierr;
248   Ceed ceed;
249   ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
250   CeedVector_Hip *impl;
251   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
252 
253   switch (copy_mode) {
254   case CEED_COPY_VALUES: {
255     CeedSize length;
256     ierr = CeedVectorGetLength(vec, &length); CeedChkBackend(ierr);
257     size_t bytes = length * sizeof(CeedScalar);
258     if (!impl->d_array_owned) {
259       ierr = hipMalloc((void **)&impl->d_array_owned, bytes);
260       CeedChk_Hip(ceed, ierr);
261     }
262     impl->d_array_borrowed = NULL;
263     impl->d_array = impl->d_array_owned;
264     if (array) {
265       ierr = hipMemcpy(impl->d_array, array, bytes,
266                        hipMemcpyDeviceToDevice); CeedChk_Hip(ceed, ierr);
267     }
268   } break;
269   case CEED_OWN_POINTER:
270     ierr = hipFree(impl->d_array_owned); CeedChk_Hip(ceed, ierr);
271     impl->d_array_owned = array;
272     impl->d_array_borrowed = NULL;
273     impl->d_array = array;
274     break;
275   case CEED_USE_POINTER:
276     ierr = hipFree(impl->d_array_owned); CeedChk_Hip(ceed, ierr);
277     impl->d_array_owned = NULL;
278     impl->d_array_borrowed = array;
279     impl->d_array = array;
280     break;
281   }
282 
283   return CEED_ERROR_SUCCESS;
284 }
285 
286 //------------------------------------------------------------------------------
287 // Set the array used by a vector,
288 //   freeing any previously allocated array if applicable
289 //------------------------------------------------------------------------------
290 static int CeedVectorSetArray_Hip(const CeedVector vec,
291                                   const CeedMemType mem_type,
292                                   const CeedCopyMode copy_mode, CeedScalar *array) {
293   int ierr;
294   Ceed ceed;
295   ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
296   CeedVector_Hip *impl;
297   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
298 
299   ierr = CeedVectorSetAllInvalid_Hip(vec); CeedChkBackend(ierr);
300   switch (mem_type) {
301   case CEED_MEM_HOST:
302     return CeedVectorSetArrayHost_Hip(vec, copy_mode, array);
303   case CEED_MEM_DEVICE:
304     return CeedVectorSetArrayDevice_Hip(vec, copy_mode, array);
305   }
306 
307   return CEED_ERROR_UNSUPPORTED;
308 }
309 
310 //------------------------------------------------------------------------------
311 // Set host array to value
312 //------------------------------------------------------------------------------
313 static int CeedHostSetValue_Hip(CeedScalar *h_array, CeedInt length,
314                                 CeedScalar val) {
315   for (int i = 0; i < length; i++)
316     h_array[i] = val;
317   return CEED_ERROR_SUCCESS;
318 }
319 
320 //------------------------------------------------------------------------------
321 // Set device array to value (impl in .hip file)
322 //------------------------------------------------------------------------------
323 int CeedDeviceSetValue_Hip(CeedScalar *d_array, CeedInt length, CeedScalar val);
324 
325 //------------------------------------------------------------------------------
// Set every entry of a vector to a given value
327 //------------------------------------------------------------------------------
328 static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) {
329   int ierr;
330   Ceed ceed;
331   ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
332   CeedVector_Hip *impl;
333   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
334   CeedSize length;
335   ierr = CeedVectorGetLength(vec, &length); CeedChkBackend(ierr);
336 
337   // Set value for synced device/host array
338   if (!impl->d_array && !impl->h_array) {
339     if (impl->d_array_borrowed) {
340       impl->d_array = impl->d_array_borrowed;
341     } else if (impl->h_array_borrowed) {
342       impl->h_array = impl->h_array_borrowed;
343     } else if (impl->d_array_owned) {
344       impl->d_array = impl->d_array_owned;
345     } else if (impl->h_array_owned) {
346       impl->h_array = impl->h_array_owned;
347     } else {
348       ierr = CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL);
349       CeedChkBackend(ierr);
350     }
351   }
352   if (impl->d_array) {
353     ierr = CeedDeviceSetValue_Hip(impl->d_array, length, val); CeedChkBackend(ierr);
354   }
355   if (impl->h_array) {
356     ierr = CeedHostSetValue_Hip(impl->h_array, length, val); CeedChkBackend(ierr);
357   }
358 
359   return CEED_ERROR_SUCCESS;
360 }
361 
362 //------------------------------------------------------------------------------
363 // Vector Take Array
364 //------------------------------------------------------------------------------
365 static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type,
366                                    CeedScalar **array) {
367   int ierr;
368   Ceed ceed;
369   ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
370   CeedVector_Hip *impl;
371   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
372 
373   // Sync array to requested mem_type
374   ierr = CeedVectorSyncArray(vec, mem_type); CeedChkBackend(ierr);
375 
376   // Update pointer
377   switch (mem_type) {
378   case CEED_MEM_HOST:
379     (*array) = impl->h_array_borrowed;
380     impl->h_array_borrowed = NULL;
381     impl->h_array = NULL;
382     break;
383   case CEED_MEM_DEVICE:
384     (*array) = impl->d_array_borrowed;
385     impl->d_array_borrowed = NULL;
386     impl->d_array = NULL;
387     break;
388   }
389 
390   return CEED_ERROR_SUCCESS;
391 }
392 
393 //------------------------------------------------------------------------------
// Core logic for array synchronization for GetArray.
395 //   If a different memory type is most up to date, this will perform a copy
396 //------------------------------------------------------------------------------
397 static int CeedVectorGetArrayCore_Hip(const CeedVector vec,
398                                       const CeedMemType mem_type, CeedScalar **array) {
399   int ierr;
400   Ceed ceed;
401   ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
402   CeedVector_Hip *impl;
403   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
404 
405   // Sync array to requested mem_type
406   ierr = CeedVectorSyncArray(vec, mem_type); CeedChkBackend(ierr);
407 
408   // Update pointer
409   switch (mem_type) {
410   case CEED_MEM_HOST:
411     *array = impl->h_array;
412     break;
413   case CEED_MEM_DEVICE:
414     *array = impl->d_array;
415     break;
416   }
417 
418   return CEED_ERROR_SUCCESS;
419 }
420 
421 //------------------------------------------------------------------------------
422 // Get read-only access to a vector via the specified mem_type
423 //------------------------------------------------------------------------------
424 static int CeedVectorGetArrayRead_Hip(const CeedVector vec,
425                                       const CeedMemType mem_type, const CeedScalar **array) {
426   return CeedVectorGetArrayCore_Hip(vec, mem_type, (CeedScalar **)array);
427 }
428 
429 //------------------------------------------------------------------------------
430 // Get read/write access to a vector via the specified mem_type
431 //------------------------------------------------------------------------------
432 static int CeedVectorGetArray_Hip(const CeedVector vec,
433                                   const CeedMemType mem_type,
434                                   CeedScalar **array) {
435   int ierr;
436   CeedVector_Hip *impl;
437   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
438 
439   ierr = CeedVectorGetArrayCore_Hip(vec, mem_type, array); CeedChkBackend(ierr);
440 
441   ierr = CeedVectorSetAllInvalid_Hip(vec); CeedChkBackend(ierr);
442   switch (mem_type) {
443   case CEED_MEM_HOST:
444     impl->h_array = *array;
445     break;
446   case CEED_MEM_DEVICE:
447     impl->d_array = *array;
448     break;
449   }
450 
451   return CEED_ERROR_SUCCESS;
452 }
453 
454 //------------------------------------------------------------------------------
455 // Get write access to a vector via the specified mem_type
456 //------------------------------------------------------------------------------
457 static int CeedVectorGetArrayWrite_Hip(const CeedVector vec,
458                                        const CeedMemType mem_type, CeedScalar **array) {
459   int ierr;
460   CeedVector_Hip *impl;
461   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
462 
463   bool has_array_of_type = true;
464   ierr = CeedVectorHasArrayOfType_Hip(vec, mem_type, &has_array_of_type);
465   CeedChkBackend(ierr);
466   if (!has_array_of_type) {
467     // Allocate if array is not yet allocated
468     ierr = CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL);
469     CeedChkBackend(ierr);
470   } else {
471     // Select dirty array
472     switch (mem_type) {
473     case CEED_MEM_HOST:
474       if (impl->h_array_borrowed)
475         impl->h_array = impl->h_array_borrowed;
476       else
477         impl->h_array = impl->h_array_owned;
478       break;
479     case CEED_MEM_DEVICE:
480       if (impl->d_array_borrowed)
481         impl->d_array = impl->d_array_borrowed;
482       else
483         impl->d_array = impl->d_array_owned;
484     }
485   }
486 
487   return CeedVectorGetArray_Hip(vec, mem_type, array);
488 }
489 
490 //------------------------------------------------------------------------------
491 // Get the norm of a CeedVector
492 //------------------------------------------------------------------------------
493 static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type,
494                               CeedScalar *norm) {
495   int ierr;
496   Ceed ceed;
497   ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
498   CeedVector_Hip *impl;
499   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
500   CeedSize length;
501   ierr = CeedVectorGetLength(vec, &length); CeedChkBackend(ierr);
502   hipblasHandle_t handle;
503   ierr = CeedHipGetHipblasHandle(ceed, &handle); CeedChkBackend(ierr);
504 
505   // Compute norm
506   const CeedScalar *d_array;
507   ierr = CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array);
508   CeedChkBackend(ierr);
509   switch (type) {
510   case CEED_NORM_1: {
511     if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
512       ierr = hipblasSasum(handle, length, (float *) d_array, 1, (float *) norm);
513     } else {
514       ierr = hipblasDasum(handle, length, (double *) d_array, 1, (double *) norm);
515     }
516     CeedChk_Hipblas(ceed, ierr);
517     break;
518   }
519   case CEED_NORM_2: {
520     if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
521       ierr = hipblasSnrm2(handle, length, (float *) d_array, 1, (float *) norm);
522     } else {
523       ierr = hipblasDnrm2(handle, length, (double *) d_array, 1, (double *) norm);
524     }
525     CeedChk_Hipblas(ceed, ierr);
526     break;
527   }
528   case CEED_NORM_MAX: {
529     CeedInt indx;
530     if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
531       ierr = hipblasIsamax(handle, length, (float *) d_array, 1, &indx);
532     } else {
533       ierr = hipblasIdamax(handle, length, (double *) d_array, 1, &indx);
534     }
535     CeedChk_Hipblas(ceed, ierr);
536     CeedScalar normNoAbs;
537     ierr = hipMemcpy(&normNoAbs, impl->d_array+indx-1, sizeof(CeedScalar),
538                      hipMemcpyDeviceToHost); CeedChk_Hip(ceed, ierr);
539     *norm = fabs(normNoAbs);
540     break;
541   }
542   }
543   ierr = CeedVectorRestoreArrayRead(vec, &d_array); CeedChkBackend(ierr);
544 
545   return CEED_ERROR_SUCCESS;
546 }
547 
548 //------------------------------------------------------------------------------
549 // Take reciprocal of a vector on host
550 //------------------------------------------------------------------------------
551 static int CeedHostReciprocal_Hip(CeedScalar *h_array, CeedInt length) {
552   for (int i = 0; i < length; i++)
553     if (fabs(h_array[i]) > CEED_EPSILON)
554       h_array[i] = 1./h_array[i];
555   return CEED_ERROR_SUCCESS;
556 }
557 
558 //------------------------------------------------------------------------------
559 // Take reciprocal of a vector on device (impl in .cu file)
560 //------------------------------------------------------------------------------
561 int CeedDeviceReciprocal_Hip(CeedScalar *d_array, CeedInt length);
562 
563 //------------------------------------------------------------------------------
564 // Take reciprocal of a vector
565 //------------------------------------------------------------------------------
566 static int CeedVectorReciprocal_Hip(CeedVector vec) {
567   int ierr;
568   Ceed ceed;
569   ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
570   CeedVector_Hip *impl;
571   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
572   CeedSize length;
573   ierr = CeedVectorGetLength(vec, &length); CeedChkBackend(ierr);
574 
575   // Set value for synced device/host array
576   if (impl->d_array) {
577     ierr = CeedDeviceReciprocal_Hip(impl->d_array, length); CeedChkBackend(ierr);
578   }
579   if (impl->h_array) {
580     ierr = CeedHostReciprocal_Hip(impl->h_array, length); CeedChkBackend(ierr);
581   }
582 
583   return CEED_ERROR_SUCCESS;
584 }
585 
586 //------------------------------------------------------------------------------
587 // Compute x = alpha x on the host
588 //------------------------------------------------------------------------------
589 static int CeedHostScale_Hip(CeedScalar *x_array, CeedScalar alpha,
590                              CeedInt length) {
591   for (int i = 0; i < length; i++)
592     x_array[i] *= alpha;
593   return CEED_ERROR_SUCCESS;
594 }
595 
596 //------------------------------------------------------------------------------
597 // Compute x = alpha x on device (impl in .cu file)
598 //------------------------------------------------------------------------------
599 int CeedDeviceScale_Hip(CeedScalar *x_array, CeedScalar alpha,
600                         CeedInt length);
601 
602 //------------------------------------------------------------------------------
603 // Compute x = alpha x
604 //------------------------------------------------------------------------------
605 static int CeedVectorScale_Hip(CeedVector x, CeedScalar alpha) {
606   int ierr;
607   Ceed ceed;
608   ierr = CeedVectorGetCeed(x, &ceed); CeedChkBackend(ierr);
609   CeedVector_Hip *x_impl;
610   ierr = CeedVectorGetData(x, &x_impl); CeedChkBackend(ierr);
611   CeedSize length;
612   ierr = CeedVectorGetLength(x, &length); CeedChkBackend(ierr);
613 
614   // Set value for synced device/host array
615   if (x_impl->d_array) {
616     ierr = CeedDeviceScale_Hip(x_impl->d_array, alpha, length);
617     CeedChkBackend(ierr);
618   }
619   if (x_impl->h_array) {
620     ierr = CeedHostScale_Hip(x_impl->h_array, alpha, length); CeedChkBackend(ierr);
621   }
622 
623   return CEED_ERROR_SUCCESS;
624 }
625 
626 //------------------------------------------------------------------------------
627 // Compute y = alpha x + y on the host
628 //------------------------------------------------------------------------------
629 static int CeedHostAXPY_Hip(CeedScalar *y_array, CeedScalar alpha,
630                             CeedScalar *x_array, CeedInt length) {
631   for (int i = 0; i < length; i++)
632     y_array[i] += alpha * x_array[i];
633   return CEED_ERROR_SUCCESS;
634 }
635 
636 //------------------------------------------------------------------------------
637 // Compute y = alpha x + y on device (impl in .cu file)
638 //------------------------------------------------------------------------------
639 int CeedDeviceAXPY_Hip(CeedScalar *y_array, CeedScalar alpha,
640                        CeedScalar *x_array, CeedInt length);
641 
642 //------------------------------------------------------------------------------
643 // Compute y = alpha x + y
644 //------------------------------------------------------------------------------
645 static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) {
646   int ierr;
647   Ceed ceed;
648   ierr = CeedVectorGetCeed(y, &ceed); CeedChkBackend(ierr);
649   CeedVector_Hip *y_impl, *x_impl;
650   ierr = CeedVectorGetData(y, &y_impl); CeedChkBackend(ierr);
651   ierr = CeedVectorGetData(x, &x_impl); CeedChkBackend(ierr);
652   CeedSize length;
653   ierr = CeedVectorGetLength(y, &length); CeedChkBackend(ierr);
654 
655   // Set value for synced device/host array
656   if (y_impl->d_array) {
657     ierr = CeedVectorSyncArray(x, CEED_MEM_DEVICE); CeedChkBackend(ierr);
658     ierr = CeedDeviceAXPY_Hip(y_impl->d_array, alpha, x_impl->d_array, length);
659     CeedChkBackend(ierr);
660   }
661   if (y_impl->h_array) {
662     ierr = CeedVectorSyncArray(x, CEED_MEM_HOST); CeedChkBackend(ierr);
663     ierr = CeedHostAXPY_Hip(y_impl->h_array, alpha, x_impl->h_array, length);
664     CeedChkBackend(ierr);
665   }
666 
667   return CEED_ERROR_SUCCESS;
668 }
669 
670 //------------------------------------------------------------------------------
671 // Compute the pointwise multiplication w = x .* y on the host
672 //------------------------------------------------------------------------------
673 static int CeedHostPointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array,
674                                      CeedScalar *y_array, CeedInt length) {
675   for (int i = 0; i < length; i++)
676     w_array[i] = x_array[i] * y_array[i];
677   return CEED_ERROR_SUCCESS;
678 }
679 
680 //------------------------------------------------------------------------------
681 // Compute the pointwise multiplication w = x .* y on device (impl in .cu file)
682 //------------------------------------------------------------------------------
683 int CeedDevicePointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array,
684                                 CeedScalar *y_array, CeedInt length);
685 
686 //------------------------------------------------------------------------------
687 // Compute the pointwise multiplication w = x .* y
688 //------------------------------------------------------------------------------
689 static int CeedVectorPointwiseMult_Hip(CeedVector w, CeedVector x,
690                                        CeedVector y) {
691   int ierr;
692   Ceed ceed;
693   ierr = CeedVectorGetCeed(w, &ceed); CeedChkBackend(ierr);
694   CeedVector_Hip *w_impl, *x_impl, *y_impl;
695   ierr = CeedVectorGetData(w, &w_impl); CeedChkBackend(ierr);
696   ierr = CeedVectorGetData(x, &x_impl); CeedChkBackend(ierr);
697   ierr = CeedVectorGetData(y, &y_impl); CeedChkBackend(ierr);
698   CeedSize length;
699   ierr = CeedVectorGetLength(w, &length); CeedChkBackend(ierr);
700 
701   // Set value for synced device/host array
702   if (!w_impl->d_array && !w_impl->h_array) {
703     ierr = CeedVectorSetValue(w, 0.0); CeedChkBackend(ierr);
704   }
705   if (w_impl->d_array) {
706     ierr = CeedVectorSyncArray(x, CEED_MEM_DEVICE); CeedChkBackend(ierr);
707     ierr = CeedVectorSyncArray(y, CEED_MEM_DEVICE); CeedChkBackend(ierr);
708     ierr = CeedDevicePointwiseMult_Hip(w_impl->d_array, x_impl->d_array,
709                                        y_impl->d_array, length);
710     CeedChkBackend(ierr);
711   }
712   if (w_impl->h_array) {
713     ierr = CeedVectorSyncArray(x, CEED_MEM_HOST); CeedChkBackend(ierr);
714     ierr = CeedVectorSyncArray(y, CEED_MEM_HOST); CeedChkBackend(ierr);
715     ierr = CeedHostPointwiseMult_Hip(w_impl->h_array, x_impl->h_array,
716                                      y_impl->h_array, length);
717     CeedChkBackend(ierr);
718   }
719 
720   return CEED_ERROR_SUCCESS;
721 }
722 
723 //------------------------------------------------------------------------------
724 // Destroy the vector
725 //------------------------------------------------------------------------------
726 static int CeedVectorDestroy_Hip(const CeedVector vec) {
727   int ierr;
728   Ceed ceed;
729   ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
730   CeedVector_Hip *impl;
731   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
732 
733   ierr = hipFree(impl->d_array_owned); CeedChk_Hip(ceed, ierr);
734   ierr = CeedFree(&impl->h_array_owned); CeedChkBackend(ierr);
735   ierr = CeedFree(&impl); CeedChkBackend(ierr);
736 
737   return CEED_ERROR_SUCCESS;
738 }
739 
740 //------------------------------------------------------------------------------
741 // Create a vector of the specified length (does not allocate memory)
742 //------------------------------------------------------------------------------
743 int CeedVectorCreate_Hip(CeedSize n, CeedVector vec) {
744   CeedVector_Hip *impl;
745   int ierr;
746   Ceed ceed;
747   ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
748 
749   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray",
750                                 CeedVectorHasValidArray_Hip); CeedChkBackend(ierr);
751   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType",
752                                 CeedVectorHasBorrowedArrayOfType_Hip);
753   CeedChkBackend(ierr);
754   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SetArray",
755                                 CeedVectorSetArray_Hip); CeedChkBackend(ierr);
756   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray",
757                                 CeedVectorTakeArray_Hip); CeedChkBackend(ierr);
758   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SetValue",
759                                 (int (*)())(CeedVectorSetValue_Hip)); CeedChkBackend(ierr);
760   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray",
761                                 CeedVectorSyncArray_Hip); CeedChkBackend(ierr);
762   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArray",
763                                 CeedVectorGetArray_Hip); CeedChkBackend(ierr);
764   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead",
765                                 CeedVectorGetArrayRead_Hip); CeedChkBackend(ierr);
766   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite",
767                                 CeedVectorGetArrayWrite_Hip); CeedChkBackend(ierr);
768   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "Norm",
769                                 CeedVectorNorm_Hip); CeedChkBackend(ierr);
770   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal",
771                                 CeedVectorReciprocal_Hip); CeedChkBackend(ierr);
772   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "Scale",
773                                 (int (*)())(CeedVectorScale_Hip)); CeedChkBackend(ierr);
774   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "AXPY",
775                                 (int (*)())(CeedVectorAXPY_Hip)); CeedChkBackend(ierr);
776   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult",
777                                 CeedVectorPointwiseMult_Hip); CeedChkBackend(ierr);
778   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "Destroy",
779                                 CeedVectorDestroy_Hip); CeedChkBackend(ierr);
780 
781   ierr = CeedCalloc(1, &impl); CeedChkBackend(ierr);
782   ierr = CeedVectorSetData(vec, impl); CeedChkBackend(ierr);
783 
784   return CEED_ERROR_SUCCESS;
785 }
786