xref: /libCEED/backends/hip-ref/ceed-hip-ref-vector.c (revision fa5adaf534a16ab6b728a4ab386a8a8b7691fe89)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <math.h>
11 #include <stdbool.h>
12 #include <string.h>
13 #include <hip/hip_runtime.h>
14 
15 #include "../hip/ceed-hip-common.h"
16 #include "ceed-hip-ref.h"
17 
18 //------------------------------------------------------------------------------
19 // Check if host/device sync is needed
20 //------------------------------------------------------------------------------
21 static inline int CeedVectorNeedSync_Hip(const CeedVector vec, CeedMemType mem_type, bool *need_sync) {
22   CeedVector_Hip *impl;
23   CeedCallBackend(CeedVectorGetData(vec, &impl));
24 
25   bool has_valid_array = false;
26   CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array));
27   switch (mem_type) {
28     case CEED_MEM_HOST:
29       *need_sync = has_valid_array && !impl->h_array;
30       break;
31     case CEED_MEM_DEVICE:
32       *need_sync = has_valid_array && !impl->d_array;
33       break;
34   }
35 
36   return CEED_ERROR_SUCCESS;
37 }
38 
39 //------------------------------------------------------------------------------
40 // Sync host to device
41 //------------------------------------------------------------------------------
42 static inline int CeedVectorSyncH2D_Hip(const CeedVector vec) {
43   Ceed ceed;
44   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
45   CeedVector_Hip *impl;
46   CeedCallBackend(CeedVectorGetData(vec, &impl));
47 
48   CeedSize length;
49   CeedCallBackend(CeedVectorGetLength(vec, &length));
50   size_t bytes = length * sizeof(CeedScalar);
51 
52   CeedCheck(impl->h_array, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
53 
54   if (impl->d_array_borrowed) {
55     impl->d_array = impl->d_array_borrowed;
56   } else if (impl->d_array_owned) {
57     impl->d_array = impl->d_array_owned;
58   } else {
59     CeedCallHip(ceed, hipMalloc((void **)&impl->d_array_owned, bytes));
60     impl->d_array = impl->d_array_owned;
61   }
62 
63   CeedCallHip(ceed, hipMemcpy(impl->d_array, impl->h_array, bytes, hipMemcpyHostToDevice));
64 
65   return CEED_ERROR_SUCCESS;
66 }
67 
68 //------------------------------------------------------------------------------
69 // Sync device to host
70 //------------------------------------------------------------------------------
71 static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) {
72   Ceed ceed;
73   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
74   CeedVector_Hip *impl;
75   CeedCallBackend(CeedVectorGetData(vec, &impl));
76 
77   CeedCheck(impl->d_array, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
78 
79   if (impl->h_array_borrowed) {
80     impl->h_array = impl->h_array_borrowed;
81   } else if (impl->h_array_owned) {
82     impl->h_array = impl->h_array_owned;
83   } else {
84     CeedSize length;
85     CeedCallBackend(CeedVectorGetLength(vec, &length));
86     CeedCallBackend(CeedCalloc(length, &impl->h_array_owned));
87     impl->h_array = impl->h_array_owned;
88   }
89 
90   CeedSize length;
91   CeedCallBackend(CeedVectorGetLength(vec, &length));
92   size_t bytes = length * sizeof(CeedScalar);
93   CeedCallHip(ceed, hipMemcpy(impl->h_array, impl->d_array, bytes, hipMemcpyDeviceToHost));
94 
95   return CEED_ERROR_SUCCESS;
96 }
97 
98 //------------------------------------------------------------------------------
99 // Sync arrays
100 //------------------------------------------------------------------------------
101 static int CeedVectorSyncArray_Hip(const CeedVector vec, CeedMemType mem_type) {
102   // Check whether device/host sync is needed
103   bool need_sync = false;
104   CeedCallBackend(CeedVectorNeedSync_Hip(vec, mem_type, &need_sync));
105   if (!need_sync) return CEED_ERROR_SUCCESS;
106 
107   switch (mem_type) {
108     case CEED_MEM_HOST:
109       return CeedVectorSyncD2H_Hip(vec);
110     case CEED_MEM_DEVICE:
111       return CeedVectorSyncH2D_Hip(vec);
112   }
113   return CEED_ERROR_UNSUPPORTED;
114 }
115 
116 //------------------------------------------------------------------------------
117 // Set all pointers as invalid
118 //------------------------------------------------------------------------------
119 static inline int CeedVectorSetAllInvalid_Hip(const CeedVector vec) {
120   CeedVector_Hip *impl;
121   CeedCallBackend(CeedVectorGetData(vec, &impl));
122 
123   impl->h_array = NULL;
124   impl->d_array = NULL;
125 
126   return CEED_ERROR_SUCCESS;
127 }
128 
129 //------------------------------------------------------------------------------
130 // Check if CeedVector has any valid pointers
131 //------------------------------------------------------------------------------
132 static inline int CeedVectorHasValidArray_Hip(const CeedVector vec, bool *has_valid_array) {
133   CeedVector_Hip *impl;
134   CeedCallBackend(CeedVectorGetData(vec, &impl));
135 
136   *has_valid_array = !!impl->h_array || !!impl->d_array;
137 
138   return CEED_ERROR_SUCCESS;
139 }
140 
141 //------------------------------------------------------------------------------
142 // Check if has any array of given type
143 //------------------------------------------------------------------------------
144 static inline int CeedVectorHasArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) {
145   CeedVector_Hip *impl;
146   CeedCallBackend(CeedVectorGetData(vec, &impl));
147 
148   switch (mem_type) {
149     case CEED_MEM_HOST:
150       *has_array_of_type = !!impl->h_array_borrowed || !!impl->h_array_owned;
151       break;
152     case CEED_MEM_DEVICE:
153       *has_array_of_type = !!impl->d_array_borrowed || !!impl->d_array_owned;
154       break;
155   }
156 
157   return CEED_ERROR_SUCCESS;
158 }
159 
160 //------------------------------------------------------------------------------
161 // Check if has borrowed array of given type
162 //------------------------------------------------------------------------------
163 static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) {
164   CeedVector_Hip *impl;
165   CeedCallBackend(CeedVectorGetData(vec, &impl));
166 
167   switch (mem_type) {
168     case CEED_MEM_HOST:
169       *has_borrowed_array_of_type = !!impl->h_array_borrowed;
170       break;
171     case CEED_MEM_DEVICE:
172       *has_borrowed_array_of_type = !!impl->d_array_borrowed;
173       break;
174   }
175 
176   return CEED_ERROR_SUCCESS;
177 }
178 
179 //------------------------------------------------------------------------------
180 // Set array from host
181 //------------------------------------------------------------------------------
182 static int CeedVectorSetArrayHost_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
183   CeedVector_Hip *impl;
184   CeedCallBackend(CeedVectorGetData(vec, &impl));
185 
186   switch (copy_mode) {
187     case CEED_COPY_VALUES: {
188       CeedSize length;
189       if (!impl->h_array_owned) {
190         CeedCallBackend(CeedVectorGetLength(vec, &length));
191         CeedCallBackend(CeedMalloc(length, &impl->h_array_owned));
192       }
193       impl->h_array_borrowed = NULL;
194       impl->h_array          = impl->h_array_owned;
195       if (array) {
196         CeedSize length;
197         CeedCallBackend(CeedVectorGetLength(vec, &length));
198         size_t bytes = length * sizeof(CeedScalar);
199         memcpy(impl->h_array, array, bytes);
200       }
201     } break;
202     case CEED_OWN_POINTER:
203       CeedCallBackend(CeedFree(&impl->h_array_owned));
204       impl->h_array_owned    = array;
205       impl->h_array_borrowed = NULL;
206       impl->h_array          = array;
207       break;
208     case CEED_USE_POINTER:
209       CeedCallBackend(CeedFree(&impl->h_array_owned));
210       impl->h_array_borrowed = array;
211       impl->h_array          = array;
212       break;
213   }
214 
215   return CEED_ERROR_SUCCESS;
216 }
217 
218 //------------------------------------------------------------------------------
219 // Set array from device
220 //------------------------------------------------------------------------------
221 static int CeedVectorSetArrayDevice_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
222   Ceed ceed;
223   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
224   CeedVector_Hip *impl;
225   CeedCallBackend(CeedVectorGetData(vec, &impl));
226 
227   switch (copy_mode) {
228     case CEED_COPY_VALUES: {
229       CeedSize length;
230       CeedCallBackend(CeedVectorGetLength(vec, &length));
231       size_t bytes = length * sizeof(CeedScalar);
232       if (!impl->d_array_owned) {
233         CeedCallHip(ceed, hipMalloc((void **)&impl->d_array_owned, bytes));
234       }
235       impl->d_array_borrowed = NULL;
236       impl->d_array          = impl->d_array_owned;
237       if (array) {
238         CeedCallHip(ceed, hipMemcpy(impl->d_array, array, bytes, hipMemcpyDeviceToDevice));
239       }
240     } break;
241     case CEED_OWN_POINTER:
242       CeedCallHip(ceed, hipFree(impl->d_array_owned));
243       impl->d_array_owned    = array;
244       impl->d_array_borrowed = NULL;
245       impl->d_array          = array;
246       break;
247     case CEED_USE_POINTER:
248       CeedCallHip(ceed, hipFree(impl->d_array_owned));
249       impl->d_array_owned    = NULL;
250       impl->d_array_borrowed = array;
251       impl->d_array          = array;
252       break;
253   }
254 
255   return CEED_ERROR_SUCCESS;
256 }
257 
258 //------------------------------------------------------------------------------
259 // Set the array used by a vector,
260 //   freeing any previously allocated array if applicable
261 //------------------------------------------------------------------------------
262 static int CeedVectorSetArray_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) {
263   Ceed ceed;
264   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
265   CeedVector_Hip *impl;
266   CeedCallBackend(CeedVectorGetData(vec, &impl));
267 
268   CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
269   switch (mem_type) {
270     case CEED_MEM_HOST:
271       return CeedVectorSetArrayHost_Hip(vec, copy_mode, array);
272     case CEED_MEM_DEVICE:
273       return CeedVectorSetArrayDevice_Hip(vec, copy_mode, array);
274   }
275 
276   return CEED_ERROR_UNSUPPORTED;
277 }
278 
279 //------------------------------------------------------------------------------
280 // Set host array to value
281 //------------------------------------------------------------------------------
282 static int CeedHostSetValue_Hip(CeedScalar *h_array, CeedSize length, CeedScalar val) {
283   for (CeedSize i = 0; i < length; i++) h_array[i] = val;
284   return CEED_ERROR_SUCCESS;
285 }
286 
//------------------------------------------------------------------------------
// Set device array to value (impl in .hip file)
//   Declared here, defined elsewhere; CeedVectorSetValue_Hip calls it on the
//   active device array.
//------------------------------------------------------------------------------
int CeedDeviceSetValue_Hip(CeedScalar *d_array, CeedSize length, CeedScalar val);
291 
292 //------------------------------------------------------------------------------
293 // Set a vector to a value,
294 //------------------------------------------------------------------------------
295 static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) {
296   Ceed ceed;
297   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
298   CeedVector_Hip *impl;
299   CeedCallBackend(CeedVectorGetData(vec, &impl));
300   CeedSize length;
301   CeedCallBackend(CeedVectorGetLength(vec, &length));
302 
303   // Set value for synced device/host array
304   if (!impl->d_array && !impl->h_array) {
305     if (impl->d_array_borrowed) {
306       impl->d_array = impl->d_array_borrowed;
307     } else if (impl->h_array_borrowed) {
308       impl->h_array = impl->h_array_borrowed;
309     } else if (impl->d_array_owned) {
310       impl->d_array = impl->d_array_owned;
311     } else if (impl->h_array_owned) {
312       impl->h_array = impl->h_array_owned;
313     } else {
314       CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL));
315     }
316   }
317   if (impl->d_array) {
318     CeedCallBackend(CeedDeviceSetValue_Hip(impl->d_array, length, val));
319   }
320   if (impl->h_array) {
321     CeedCallBackend(CeedHostSetValue_Hip(impl->h_array, length, val));
322   }
323 
324   return CEED_ERROR_SUCCESS;
325 }
326 
327 //------------------------------------------------------------------------------
328 // Vector Take Array
329 //------------------------------------------------------------------------------
330 static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
331   Ceed ceed;
332   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
333   CeedVector_Hip *impl;
334   CeedCallBackend(CeedVectorGetData(vec, &impl));
335 
336   // Sync array to requested mem_type
337   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
338 
339   // Update pointer
340   switch (mem_type) {
341     case CEED_MEM_HOST:
342       (*array)               = impl->h_array_borrowed;
343       impl->h_array_borrowed = NULL;
344       impl->h_array          = NULL;
345       break;
346     case CEED_MEM_DEVICE:
347       (*array)               = impl->d_array_borrowed;
348       impl->d_array_borrowed = NULL;
349       impl->d_array          = NULL;
350       break;
351   }
352 
353   return CEED_ERROR_SUCCESS;
354 }
355 
356 //------------------------------------------------------------------------------
357 // Core logic for array syncronization for GetArray.
358 //   If a different memory type is most up to date, this will perform a copy
359 //------------------------------------------------------------------------------
360 static int CeedVectorGetArrayCore_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
361   Ceed ceed;
362   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
363   CeedVector_Hip *impl;
364   CeedCallBackend(CeedVectorGetData(vec, &impl));
365 
366   // Sync array to requested mem_type
367   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
368 
369   // Update pointer
370   switch (mem_type) {
371     case CEED_MEM_HOST:
372       *array = impl->h_array;
373       break;
374     case CEED_MEM_DEVICE:
375       *array = impl->d_array;
376       break;
377   }
378 
379   return CEED_ERROR_SUCCESS;
380 }
381 
382 //------------------------------------------------------------------------------
383 // Get read-only access to a vector via the specified mem_type
384 //------------------------------------------------------------------------------
385 static int CeedVectorGetArrayRead_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) {
386   return CeedVectorGetArrayCore_Hip(vec, mem_type, (CeedScalar **)array);
387 }
388 
389 //------------------------------------------------------------------------------
390 // Get read/write access to a vector via the specified mem_type
391 //------------------------------------------------------------------------------
392 static int CeedVectorGetArray_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
393   CeedVector_Hip *impl;
394   CeedCallBackend(CeedVectorGetData(vec, &impl));
395 
396   CeedCallBackend(CeedVectorGetArrayCore_Hip(vec, mem_type, array));
397 
398   CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
399   switch (mem_type) {
400     case CEED_MEM_HOST:
401       impl->h_array = *array;
402       break;
403     case CEED_MEM_DEVICE:
404       impl->d_array = *array;
405       break;
406   }
407 
408   return CEED_ERROR_SUCCESS;
409 }
410 
411 //------------------------------------------------------------------------------
412 // Get write access to a vector via the specified mem_type
413 //------------------------------------------------------------------------------
414 static int CeedVectorGetArrayWrite_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
415   CeedVector_Hip *impl;
416   CeedCallBackend(CeedVectorGetData(vec, &impl));
417 
418   bool has_array_of_type = true;
419   CeedCallBackend(CeedVectorHasArrayOfType_Hip(vec, mem_type, &has_array_of_type));
420   if (!has_array_of_type) {
421     // Allocate if array is not yet allocated
422     CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL));
423   } else {
424     // Select dirty array
425     switch (mem_type) {
426       case CEED_MEM_HOST:
427         if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed;
428         else impl->h_array = impl->h_array_owned;
429         break;
430       case CEED_MEM_DEVICE:
431         if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed;
432         else impl->d_array = impl->d_array_owned;
433     }
434   }
435 
436   return CeedVectorGetArray_Hip(vec, mem_type, array);
437 }
438 
439 //------------------------------------------------------------------------------
440 // Get the norm of a CeedVector
441 //------------------------------------------------------------------------------
442 static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *norm) {
443   Ceed ceed;
444   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
445   CeedVector_Hip *impl;
446   CeedCallBackend(CeedVectorGetData(vec, &impl));
447   CeedSize length;
448   CeedCallBackend(CeedVectorGetLength(vec, &length));
449   hipblasHandle_t handle;
450   CeedCallBackend(CeedGetHipblasHandle_Hip(ceed, &handle));
451 
452   // Is the vector too long to handle with int32? If so, we will divide
453   // it up into "int32-sized" subsections and make repeated BLAS calls.
454   CeedSize num_calls = length / INT_MAX;
455   if (length % INT_MAX > 0) num_calls += 1;
456 
457   // Compute norm
458   const CeedScalar *d_array;
459   CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array));
460   switch (type) {
461     case CEED_NORM_1: {
462       *norm = 0.0;
463       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
464         float  sub_norm = 0.0;
465         float *d_array_start;
466         for (CeedInt i = 0; i < num_calls; i++) {
467           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
468           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
469           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
470           CeedCallHipblas(ceed, hipblasSasum(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
471           *norm += sub_norm;
472         }
473       } else {
474         double  sub_norm = 0.0;
475         double *d_array_start;
476         for (CeedInt i = 0; i < num_calls; i++) {
477           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
478           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
479           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
480           CeedCallHipblas(ceed, hipblasDasum(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
481           *norm += sub_norm;
482         }
483       }
484       break;
485     }
486     case CEED_NORM_2: {
487       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
488         float  sub_norm = 0.0, norm_sum = 0.0;
489         float *d_array_start;
490         for (CeedInt i = 0; i < num_calls; i++) {
491           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
492           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
493           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
494           CeedCallHipblas(ceed, hipblasSnrm2(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
495           norm_sum += sub_norm * sub_norm;
496         }
497         *norm = sqrt(norm_sum);
498       } else {
499         double  sub_norm = 0.0, norm_sum = 0.0;
500         double *d_array_start;
501         for (CeedInt i = 0; i < num_calls; i++) {
502           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
503           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
504           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
505           CeedCallHipblas(ceed, hipblasDnrm2(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
506           norm_sum += sub_norm * sub_norm;
507         }
508         *norm = sqrt(norm_sum);
509       }
510       break;
511     }
512     case CEED_NORM_MAX: {
513       CeedInt indx;
514       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
515         float  sub_max = 0.0, current_max = 0.0;
516         float *d_array_start;
517         for (CeedInt i = 0; i < num_calls; i++) {
518           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
519           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
520           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
521           CeedCallHipblas(ceed, hipblasIsamax(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &indx));
522           CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + indx - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
523           if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
524         }
525         *norm = current_max;
526       } else {
527         double  sub_max = 0.0, current_max = 0.0;
528         double *d_array_start;
529         for (CeedInt i = 0; i < num_calls; i++) {
530           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
531           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
532           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
533           CeedCallHipblas(ceed, hipblasIdamax(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &indx));
534           CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + indx - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
535           if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
536         }
537         *norm = current_max;
538       }
539       break;
540     }
541   }
542   CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array));
543 
544   return CEED_ERROR_SUCCESS;
545 }
546 
547 //------------------------------------------------------------------------------
548 // Take reciprocal of a vector on host
549 //------------------------------------------------------------------------------
550 static int CeedHostReciprocal_Hip(CeedScalar *h_array, CeedSize length) {
551   for (CeedSize i = 0; i < length; i++) {
552     if (fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i];
553   }
554   return CEED_ERROR_SUCCESS;
555 }
556 
//------------------------------------------------------------------------------
// Take reciprocal of a vector on device (impl in .hip file)
//   Declared here, defined elsewhere; CeedVectorReciprocal_Hip calls it on the
//   active device array.
//------------------------------------------------------------------------------
int CeedDeviceReciprocal_Hip(CeedScalar *d_array, CeedSize length);
561 
562 //------------------------------------------------------------------------------
563 // Take reciprocal of a vector
564 //------------------------------------------------------------------------------
565 static int CeedVectorReciprocal_Hip(CeedVector vec) {
566   Ceed ceed;
567   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
568   CeedVector_Hip *impl;
569   CeedCallBackend(CeedVectorGetData(vec, &impl));
570   CeedSize length;
571   CeedCallBackend(CeedVectorGetLength(vec, &length));
572 
573   // Set value for synced device/host array
574   if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Hip(impl->d_array, length));
575   if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Hip(impl->h_array, length));
576 
577   return CEED_ERROR_SUCCESS;
578 }
579 
580 //------------------------------------------------------------------------------
581 // Compute x = alpha x on the host
582 //------------------------------------------------------------------------------
583 static int CeedHostScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length) {
584   for (CeedSize i = 0; i < length; i++) x_array[i] *= alpha;
585   return CEED_ERROR_SUCCESS;
586 }
587 
//------------------------------------------------------------------------------
// Compute x = alpha x on device (impl in .hip file)
//   Declared here, defined elsewhere; CeedVectorScale_Hip calls it on the
//   active device array.
//------------------------------------------------------------------------------
int CeedDeviceScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length);
592 
593 //------------------------------------------------------------------------------
594 // Compute x = alpha x
595 //------------------------------------------------------------------------------
596 static int CeedVectorScale_Hip(CeedVector x, CeedScalar alpha) {
597   Ceed ceed;
598   CeedCallBackend(CeedVectorGetCeed(x, &ceed));
599   CeedVector_Hip *x_impl;
600   CeedCallBackend(CeedVectorGetData(x, &x_impl));
601   CeedSize length;
602   CeedCallBackend(CeedVectorGetLength(x, &length));
603 
604   // Set value for synced device/host array
605   if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Hip(x_impl->d_array, alpha, length));
606   if (x_impl->h_array) CeedCallBackend(CeedHostScale_Hip(x_impl->h_array, alpha, length));
607 
608   return CEED_ERROR_SUCCESS;
609 }
610 
611 //------------------------------------------------------------------------------
612 // Compute y = alpha x + y on the host
613 //------------------------------------------------------------------------------
614 static int CeedHostAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) {
615   for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i];
616   return CEED_ERROR_SUCCESS;
617 }
618 
//------------------------------------------------------------------------------
// Compute y = alpha x + y on device (impl in .hip file)
//   Declared here, defined elsewhere; CeedVectorAXPY_Hip calls it with the
//   device arrays of y and x.
//------------------------------------------------------------------------------
int CeedDeviceAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length);
623 
624 //------------------------------------------------------------------------------
625 // Compute y = alpha x + y
626 //------------------------------------------------------------------------------
627 static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) {
628   Ceed ceed;
629   CeedCallBackend(CeedVectorGetCeed(y, &ceed));
630   CeedVector_Hip *y_impl, *x_impl;
631   CeedCallBackend(CeedVectorGetData(y, &y_impl));
632   CeedCallBackend(CeedVectorGetData(x, &x_impl));
633   CeedSize length;
634   CeedCallBackend(CeedVectorGetLength(y, &length));
635 
636   // Set value for synced device/host array
637   if (y_impl->d_array) {
638     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
639     CeedCallBackend(CeedDeviceAXPY_Hip(y_impl->d_array, alpha, x_impl->d_array, length));
640   }
641   if (y_impl->h_array) {
642     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
643     CeedCallBackend(CeedHostAXPY_Hip(y_impl->h_array, alpha, x_impl->h_array, length));
644   }
645 
646   return CEED_ERROR_SUCCESS;
647 }
648 
649 //------------------------------------------------------------------------------
650 // Compute y = alpha x + beta y on the host
651 //------------------------------------------------------------------------------
652 static int CeedHostAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length) {
653   for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i] + beta * y_array[i];
654   return CEED_ERROR_SUCCESS;
655 }
656 
//------------------------------------------------------------------------------
// Compute y = alpha x + beta y on device (impl in .hip file)
//   Declared here, defined elsewhere; CeedVectorAXPBY_Hip calls it with the
//   device arrays of y and x.
//------------------------------------------------------------------------------
int CeedDeviceAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length);
661 
662 //------------------------------------------------------------------------------
663 // Compute y = alpha x + beta y
664 //------------------------------------------------------------------------------
665 static int CeedVectorAXPBY_Hip(CeedVector y, CeedScalar alpha, CeedScalar beta, CeedVector x) {
666   Ceed ceed;
667   CeedCallBackend(CeedVectorGetCeed(y, &ceed));
668   CeedVector_Hip *y_impl, *x_impl;
669   CeedCallBackend(CeedVectorGetData(y, &y_impl));
670   CeedCallBackend(CeedVectorGetData(x, &x_impl));
671   CeedSize length;
672   CeedCallBackend(CeedVectorGetLength(y, &length));
673 
674   // Set value for synced device/host array
675   if (y_impl->d_array) {
676     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
677     CeedCallBackend(CeedDeviceAXPBY_Hip(y_impl->d_array, alpha, beta, x_impl->d_array, length));
678   }
679   if (y_impl->h_array) {
680     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
681     CeedCallBackend(CeedHostAXPBY_Hip(y_impl->h_array, alpha, beta, x_impl->h_array, length));
682   }
683 
684   return CEED_ERROR_SUCCESS;
685 }
686 
687 //------------------------------------------------------------------------------
688 // Compute the pointwise multiplication w = x .* y on the host
689 //------------------------------------------------------------------------------
690 static int CeedHostPointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) {
691   for (CeedSize i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i];
692   return CEED_ERROR_SUCCESS;
693 }
694 
//------------------------------------------------------------------------------
// Compute the pointwise multiplication w = x .* y on device (impl in .hip file)
//   Declared here, defined elsewhere; CeedVectorPointwiseMult_Hip calls it
//   with the device arrays of w, x and y.
//------------------------------------------------------------------------------
int CeedDevicePointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length);
699 
700 //------------------------------------------------------------------------------
701 // Compute the pointwise multiplication w = x .* y
702 //------------------------------------------------------------------------------
703 static int CeedVectorPointwiseMult_Hip(CeedVector w, CeedVector x, CeedVector y) {
704   Ceed ceed;
705   CeedCallBackend(CeedVectorGetCeed(w, &ceed));
706   CeedVector_Hip *w_impl, *x_impl, *y_impl;
707   CeedCallBackend(CeedVectorGetData(w, &w_impl));
708   CeedCallBackend(CeedVectorGetData(x, &x_impl));
709   CeedCallBackend(CeedVectorGetData(y, &y_impl));
710   CeedSize length;
711   CeedCallBackend(CeedVectorGetLength(w, &length));
712 
713   // Set value for synced device/host array
714   if (!w_impl->d_array && !w_impl->h_array) {
715     CeedCallBackend(CeedVectorSetValue(w, 0.0));
716   }
717   if (w_impl->d_array) {
718     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
719     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE));
720     CeedCallBackend(CeedDevicePointwiseMult_Hip(w_impl->d_array, x_impl->d_array, y_impl->d_array, length));
721   }
722   if (w_impl->h_array) {
723     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
724     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST));
725     CeedCallBackend(CeedHostPointwiseMult_Hip(w_impl->h_array, x_impl->h_array, y_impl->h_array, length));
726   }
727 
728   return CEED_ERROR_SUCCESS;
729 }
730 
731 //------------------------------------------------------------------------------
732 // Destroy the vector
733 //------------------------------------------------------------------------------
734 static int CeedVectorDestroy_Hip(const CeedVector vec) {
735   Ceed ceed;
736   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
737   CeedVector_Hip *impl;
738   CeedCallBackend(CeedVectorGetData(vec, &impl));
739 
740   CeedCallHip(ceed, hipFree(impl->d_array_owned));
741   CeedCallBackend(CeedFree(&impl->h_array_owned));
742   CeedCallBackend(CeedFree(&impl));
743 
744   return CEED_ERROR_SUCCESS;
745 }
746 
747 //------------------------------------------------------------------------------
748 // Create a vector of the specified length (does not allocate memory)
749 //------------------------------------------------------------------------------
750 int CeedVectorCreate_Hip(CeedSize n, CeedVector vec) {
751   CeedVector_Hip *impl;
752   Ceed            ceed;
753   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
754 
755   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Hip));
756   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Hip));
757   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Hip));
758   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Hip));
759   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", (int (*)())CeedVectorSetValue_Hip));
760   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Hip));
761   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Hip));
762   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Hip));
763   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Hip));
764   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Norm", CeedVectorNorm_Hip));
765   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Hip));
766   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Scale", (int (*)())CeedVectorScale_Hip));
767   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPY", (int (*)())CeedVectorAXPY_Hip));
768   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPBY", (int (*)())CeedVectorAXPBY_Hip));
769   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Hip));
770   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Hip));
771 
772   CeedCallBackend(CeedCalloc(1, &impl));
773   CeedCallBackend(CeedVectorSetData(vec, impl));
774 
775   return CEED_ERROR_SUCCESS;
776 }
777 
778 //------------------------------------------------------------------------------
779