xref: /libCEED/backends/hip-ref/ceed-hip-ref-vector.c (revision f07714d9267ae37c0c844e07b7f7e93ea3da4ade)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <math.h>
11 #include <stdbool.h>
12 #include <string.h>
13 #include <hip/hip_runtime.h>
14 
15 #include "../hip/ceed-hip-common.h"
16 #include "ceed-hip-ref.h"
17 
18 //------------------------------------------------------------------------------
19 // Check if host/device sync is needed
20 //------------------------------------------------------------------------------
21 static inline int CeedVectorNeedSync_Hip(const CeedVector vec, CeedMemType mem_type, bool *need_sync) {
22   CeedVector_Hip *impl;
23   bool            has_valid_array = false;
24 
25   CeedCallBackend(CeedVectorGetData(vec, &impl));
26   CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array));
27   switch (mem_type) {
28     case CEED_MEM_HOST:
29       *need_sync = has_valid_array && !impl->h_array;
30       break;
31     case CEED_MEM_DEVICE:
32       *need_sync = has_valid_array && !impl->d_array;
33       break;
34   }
35   return CEED_ERROR_SUCCESS;
36 }
37 
38 //------------------------------------------------------------------------------
39 // Sync host to device
40 //------------------------------------------------------------------------------
41 static inline int CeedVectorSyncH2D_Hip(const CeedVector vec) {
42   Ceed            ceed;
43   CeedSize        length;
44   size_t          bytes;
45   CeedVector_Hip *impl;
46 
47   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
48   CeedCallBackend(CeedVectorGetData(vec, &impl));
49 
50   CeedCheck(impl->h_array, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
51 
52   CeedCallBackend(CeedVectorGetLength(vec, &length));
53   bytes = length * sizeof(CeedScalar);
54   if (impl->d_array_borrowed) {
55     impl->d_array = impl->d_array_borrowed;
56   } else if (impl->d_array_owned) {
57     impl->d_array = impl->d_array_owned;
58   } else {
59     CeedCallHip(ceed, hipMalloc((void **)&impl->d_array_owned, bytes));
60     impl->d_array = impl->d_array_owned;
61   }
62   CeedCallHip(ceed, hipMemcpy(impl->d_array, impl->h_array, bytes, hipMemcpyHostToDevice));
63   return CEED_ERROR_SUCCESS;
64 }
65 
66 //------------------------------------------------------------------------------
67 // Sync device to host
68 //------------------------------------------------------------------------------
69 static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) {
70   Ceed            ceed;
71   CeedSize        length;
72   size_t          bytes;
73   CeedVector_Hip *impl;
74 
75   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
76   CeedCallBackend(CeedVectorGetData(vec, &impl));
77 
78   CeedCheck(impl->d_array, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
79 
80   if (impl->h_array_borrowed) {
81     impl->h_array = impl->h_array_borrowed;
82   } else if (impl->h_array_owned) {
83     impl->h_array = impl->h_array_owned;
84   } else {
85     CeedSize length;
86 
87     CeedCallBackend(CeedVectorGetLength(vec, &length));
88     CeedCallBackend(CeedCalloc(length, &impl->h_array_owned));
89     impl->h_array = impl->h_array_owned;
90   }
91 
92   CeedCallBackend(CeedVectorGetLength(vec, &length));
93   bytes = length * sizeof(CeedScalar);
94   CeedCallHip(ceed, hipMemcpy(impl->h_array, impl->d_array, bytes, hipMemcpyDeviceToHost));
95   return CEED_ERROR_SUCCESS;
96 }
97 
98 //------------------------------------------------------------------------------
99 // Sync arrays
100 //------------------------------------------------------------------------------
101 static int CeedVectorSyncArray_Hip(const CeedVector vec, CeedMemType mem_type) {
102   bool need_sync = false;
103 
104   // Check whether device/host sync is needed
105   CeedCallBackend(CeedVectorNeedSync_Hip(vec, mem_type, &need_sync));
106   if (!need_sync) return CEED_ERROR_SUCCESS;
107 
108   switch (mem_type) {
109     case CEED_MEM_HOST:
110       return CeedVectorSyncD2H_Hip(vec);
111     case CEED_MEM_DEVICE:
112       return CeedVectorSyncH2D_Hip(vec);
113   }
114   return CEED_ERROR_UNSUPPORTED;
115 }
116 
117 //------------------------------------------------------------------------------
118 // Set all pointers as invalid
119 //------------------------------------------------------------------------------
120 static inline int CeedVectorSetAllInvalid_Hip(const CeedVector vec) {
121   CeedVector_Hip *impl;
122 
123   CeedCallBackend(CeedVectorGetData(vec, &impl));
124   impl->h_array = NULL;
125   impl->d_array = NULL;
126   return CEED_ERROR_SUCCESS;
127 }
128 
129 //------------------------------------------------------------------------------
130 // Check if CeedVector has any valid pointer
131 //------------------------------------------------------------------------------
132 static inline int CeedVectorHasValidArray_Hip(const CeedVector vec, bool *has_valid_array) {
133   CeedVector_Hip *impl;
134 
135   CeedCallBackend(CeedVectorGetData(vec, &impl));
136   *has_valid_array = impl->h_array || impl->d_array;
137   return CEED_ERROR_SUCCESS;
138 }
139 
140 //------------------------------------------------------------------------------
141 // Check if has array of given type
142 //------------------------------------------------------------------------------
143 static inline int CeedVectorHasArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) {
144   CeedVector_Hip *impl;
145 
146   CeedCallBackend(CeedVectorGetData(vec, &impl));
147   switch (mem_type) {
148     case CEED_MEM_HOST:
149       *has_array_of_type = impl->h_array_borrowed || impl->h_array_owned;
150       break;
151     case CEED_MEM_DEVICE:
152       *has_array_of_type = impl->d_array_borrowed || impl->d_array_owned;
153       break;
154   }
155   return CEED_ERROR_SUCCESS;
156 }
157 
158 //------------------------------------------------------------------------------
159 // Check if has borrowed array of given type
160 //------------------------------------------------------------------------------
161 static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) {
162   CeedVector_Hip *impl;
163 
164   CeedCallBackend(CeedVectorGetData(vec, &impl));
165   switch (mem_type) {
166     case CEED_MEM_HOST:
167       *has_borrowed_array_of_type = impl->h_array_borrowed;
168       break;
169     case CEED_MEM_DEVICE:
170       *has_borrowed_array_of_type = impl->d_array_borrowed;
171       break;
172   }
173   return CEED_ERROR_SUCCESS;
174 }
175 
176 //------------------------------------------------------------------------------
177 // Set array from host
178 //------------------------------------------------------------------------------
179 static int CeedVectorSetArrayHost_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
180   CeedVector_Hip *impl;
181 
182   CeedCallBackend(CeedVectorGetData(vec, &impl));
183   switch (copy_mode) {
184     case CEED_COPY_VALUES: {
185       CeedSize length;
186 
187       if (!impl->h_array_owned) {
188         CeedCallBackend(CeedVectorGetLength(vec, &length));
189         CeedCallBackend(CeedMalloc(length, &impl->h_array_owned));
190       }
191       impl->h_array_borrowed = NULL;
192       impl->h_array          = impl->h_array_owned;
193       if (array) {
194         CeedSize length;
195         size_t   bytes;
196 
197         CeedCallBackend(CeedVectorGetLength(vec, &length));
198         bytes = length * sizeof(CeedScalar);
199         memcpy(impl->h_array, array, bytes);
200       }
201     } break;
202     case CEED_OWN_POINTER:
203       CeedCallBackend(CeedFree(&impl->h_array_owned));
204       impl->h_array_owned    = array;
205       impl->h_array_borrowed = NULL;
206       impl->h_array          = array;
207       break;
208     case CEED_USE_POINTER:
209       CeedCallBackend(CeedFree(&impl->h_array_owned));
210       impl->h_array_borrowed = array;
211       impl->h_array          = array;
212       break;
213   }
214   return CEED_ERROR_SUCCESS;
215 }
216 
217 //------------------------------------------------------------------------------
218 // Set array from device
219 //------------------------------------------------------------------------------
220 static int CeedVectorSetArrayDevice_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
221   Ceed            ceed;
222   CeedVector_Hip *impl;
223 
224   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
225   CeedCallBackend(CeedVectorGetData(vec, &impl));
226   switch (copy_mode) {
227     case CEED_COPY_VALUES: {
228       CeedSize length;
229       size_t   bytes;
230 
231       CeedCallBackend(CeedVectorGetLength(vec, &length));
232       bytes = length * sizeof(CeedScalar);
233       if (!impl->d_array_owned) {
234         CeedCallHip(ceed, hipMalloc((void **)&impl->d_array_owned, bytes));
235       }
236       impl->d_array_borrowed = NULL;
237       impl->d_array          = impl->d_array_owned;
238       if (array) CeedCallHip(ceed, hipMemcpy(impl->d_array, array, bytes, hipMemcpyDeviceToDevice));
239     } break;
240     case CEED_OWN_POINTER:
241       CeedCallHip(ceed, hipFree(impl->d_array_owned));
242       impl->d_array_owned    = array;
243       impl->d_array_borrowed = NULL;
244       impl->d_array          = array;
245       break;
246     case CEED_USE_POINTER:
247       CeedCallHip(ceed, hipFree(impl->d_array_owned));
248       impl->d_array_owned    = NULL;
249       impl->d_array_borrowed = array;
250       impl->d_array          = array;
251       break;
252   }
253   return CEED_ERROR_SUCCESS;
254 }
255 
256 //------------------------------------------------------------------------------
257 // Set the array used by a vector,
258 //   freeing any previously allocated array if applicable
259 //------------------------------------------------------------------------------
260 static int CeedVectorSetArray_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) {
261   CeedVector_Hip *impl;
262 
263   CeedCallBackend(CeedVectorGetData(vec, &impl));
264   CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
265   switch (mem_type) {
266     case CEED_MEM_HOST:
267       return CeedVectorSetArrayHost_Hip(vec, copy_mode, array);
268     case CEED_MEM_DEVICE:
269       return CeedVectorSetArrayDevice_Hip(vec, copy_mode, array);
270   }
271   return CEED_ERROR_UNSUPPORTED;
272 }
273 
274 //------------------------------------------------------------------------------
275 // Set host array to value
276 //------------------------------------------------------------------------------
277 static int CeedHostSetValue_Hip(CeedScalar *h_array, CeedSize length, CeedScalar val) {
278   for (CeedSize i = 0; i < length; i++) h_array[i] = val;
279   return CEED_ERROR_SUCCESS;
280 }
281 
282 //------------------------------------------------------------------------------
283 // Set device array to value (impl in .hip file)
284 //------------------------------------------------------------------------------
285 int CeedDeviceSetValue_Hip(CeedScalar *d_array, CeedSize length, CeedScalar val);
286 
287 //------------------------------------------------------------------------------
288 // Set a vector to a value
289 //------------------------------------------------------------------------------
290 static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) {
291   CeedSize        length;
292   CeedVector_Hip *impl;
293 
294   CeedCallBackend(CeedVectorGetData(vec, &impl));
295   CeedCallBackend(CeedVectorGetLength(vec, &length));
296   // Set value for synced device/host array
297   if (!impl->d_array && !impl->h_array) {
298     if (impl->d_array_borrowed) {
299       impl->d_array = impl->d_array_borrowed;
300     } else if (impl->h_array_borrowed) {
301       impl->h_array = impl->h_array_borrowed;
302     } else if (impl->d_array_owned) {
303       impl->d_array = impl->d_array_owned;
304     } else if (impl->h_array_owned) {
305       impl->h_array = impl->h_array_owned;
306     } else {
307       CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL));
308     }
309   }
310   if (impl->d_array) {
311     CeedCallBackend(CeedDeviceSetValue_Hip(impl->d_array, length, val));
312     impl->h_array = NULL;
313   }
314   if (impl->h_array) {
315     CeedCallBackend(CeedHostSetValue_Hip(impl->h_array, length, val));
316     impl->d_array = NULL;
317   }
318   return CEED_ERROR_SUCCESS;
319 }
320 
321 //------------------------------------------------------------------------------
322 // Vector Take Array
323 //------------------------------------------------------------------------------
324 static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
325   CeedVector_Hip *impl;
326 
327   CeedCallBackend(CeedVectorGetData(vec, &impl));
328 
329   // Sync array to requested mem_type
330   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
331 
332   // Update pointer
333   switch (mem_type) {
334     case CEED_MEM_HOST:
335       (*array)               = impl->h_array_borrowed;
336       impl->h_array_borrowed = NULL;
337       impl->h_array          = NULL;
338       break;
339     case CEED_MEM_DEVICE:
340       (*array)               = impl->d_array_borrowed;
341       impl->d_array_borrowed = NULL;
342       impl->d_array          = NULL;
343       break;
344   }
345   return CEED_ERROR_SUCCESS;
346 }
347 
348 //------------------------------------------------------------------------------
349 // Core logic for array syncronization for GetArray.
350 //   If a different memory type is most up to date, this will perform a copy
351 //------------------------------------------------------------------------------
352 static int CeedVectorGetArrayCore_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
353   CeedVector_Hip *impl;
354 
355   CeedCallBackend(CeedVectorGetData(vec, &impl));
356 
357   // Sync array to requested mem_type
358   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
359 
360   // Update pointer
361   switch (mem_type) {
362     case CEED_MEM_HOST:
363       *array = impl->h_array;
364       break;
365     case CEED_MEM_DEVICE:
366       *array = impl->d_array;
367       break;
368   }
369   return CEED_ERROR_SUCCESS;
370 }
371 
372 //------------------------------------------------------------------------------
373 // Get read-only access to a vector via the specified mem_type
374 //------------------------------------------------------------------------------
375 static int CeedVectorGetArrayRead_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) {
376   return CeedVectorGetArrayCore_Hip(vec, mem_type, (CeedScalar **)array);
377 }
378 
379 //------------------------------------------------------------------------------
380 // Get read/write access to a vector via the specified mem_type
381 //------------------------------------------------------------------------------
382 static int CeedVectorGetArray_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
383   CeedVector_Hip *impl;
384 
385   CeedCallBackend(CeedVectorGetData(vec, &impl));
386   CeedCallBackend(CeedVectorGetArrayCore_Hip(vec, mem_type, array));
387   CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
388   switch (mem_type) {
389     case CEED_MEM_HOST:
390       impl->h_array = *array;
391       break;
392     case CEED_MEM_DEVICE:
393       impl->d_array = *array;
394       break;
395   }
396   return CEED_ERROR_SUCCESS;
397 }
398 
399 //------------------------------------------------------------------------------
400 // Get write access to a vector via the specified mem_type
401 //------------------------------------------------------------------------------
402 static int CeedVectorGetArrayWrite_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
403   bool            has_array_of_type = true;
404   CeedVector_Hip *impl;
405 
406   CeedCallBackend(CeedVectorGetData(vec, &impl));
407   CeedCallBackend(CeedVectorHasArrayOfType_Hip(vec, mem_type, &has_array_of_type));
408   if (!has_array_of_type) {
409     // Allocate if array is not yet allocated
410     CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL));
411   } else {
412     // Select dirty array
413     switch (mem_type) {
414       case CEED_MEM_HOST:
415         if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed;
416         else impl->h_array = impl->h_array_owned;
417         break;
418       case CEED_MEM_DEVICE:
419         if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed;
420         else impl->d_array = impl->d_array_owned;
421     }
422   }
423   return CeedVectorGetArray_Hip(vec, mem_type, array);
424 }
425 
426 //------------------------------------------------------------------------------
427 // Get the norm of a CeedVector
428 //------------------------------------------------------------------------------
429 static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *norm) {
430   Ceed              ceed;
431   CeedSize          length, num_calls;
432   const CeedScalar *d_array;
433   CeedVector_Hip   *impl;
434   hipblasHandle_t   handle;
435 
436   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
437   CeedCallBackend(CeedVectorGetData(vec, &impl));
438   CeedCallBackend(CeedVectorGetLength(vec, &length));
439   CeedCallBackend(CeedGetHipblasHandle_Hip(ceed, &handle));
440 
441   // Is the vector too long to handle with int32? If so, we will divide
442   // it up into "int32-sized" subsections and make repeated BLAS calls.
443   num_calls = length / INT_MAX;
444   if (length % INT_MAX > 0) num_calls += 1;
445 
446   // Compute norm
447   CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array));
448   switch (type) {
449     case CEED_NORM_1: {
450       *norm = 0.0;
451       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
452         float  sub_norm = 0.0;
453         float *d_array_start;
454 
455         for (CeedInt i = 0; i < num_calls; i++) {
456           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
457           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
458           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
459 
460           CeedCallHipblas(ceed, hipblasSasum(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
461           *norm += sub_norm;
462         }
463       } else {
464         double  sub_norm = 0.0;
465         double *d_array_start;
466 
467         for (CeedInt i = 0; i < num_calls; i++) {
468           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
469           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
470           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
471 
472           CeedCallHipblas(ceed, hipblasDasum(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
473           *norm += sub_norm;
474         }
475       }
476       break;
477     }
478     case CEED_NORM_2: {
479       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
480         float  sub_norm = 0.0, norm_sum = 0.0;
481         float *d_array_start;
482 
483         for (CeedInt i = 0; i < num_calls; i++) {
484           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
485           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
486           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
487 
488           CeedCallHipblas(ceed, hipblasSnrm2(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
489           norm_sum += sub_norm * sub_norm;
490         }
491         *norm = sqrt(norm_sum);
492       } else {
493         double  sub_norm = 0.0, norm_sum = 0.0;
494         double *d_array_start;
495 
496         for (CeedInt i = 0; i < num_calls; i++) {
497           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
498           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
499           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
500 
501           CeedCallHipblas(ceed, hipblasDnrm2(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
502           norm_sum += sub_norm * sub_norm;
503         }
504         *norm = sqrt(norm_sum);
505       }
506       break;
507     }
508     case CEED_NORM_MAX: {
509       CeedInt index;
510 
511       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
512         float  sub_max = 0.0, current_max = 0.0;
513         float *d_array_start;
514         for (CeedInt i = 0; i < num_calls; i++) {
515           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
516           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
517           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
518 
519           CeedCallHipblas(ceed, hipblasIsamax(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &index));
520           CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
521           if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
522         }
523         *norm = current_max;
524       } else {
525         double  sub_max = 0.0, current_max = 0.0;
526         double *d_array_start;
527 
528         for (CeedInt i = 0; i < num_calls; i++) {
529           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
530           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
531           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
532 
533           CeedCallHipblas(ceed, hipblasIdamax(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &index));
534           CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
535           if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
536         }
537         *norm = current_max;
538       }
539       break;
540     }
541   }
542   CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array));
543   return CEED_ERROR_SUCCESS;
544 }
545 
546 //------------------------------------------------------------------------------
547 // Take reciprocal of a vector on host
548 //------------------------------------------------------------------------------
549 static int CeedHostReciprocal_Hip(CeedScalar *h_array, CeedSize length) {
550   for (CeedSize i = 0; i < length; i++) {
551     if (fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i];
552   }
553   return CEED_ERROR_SUCCESS;
554 }
555 
556 //------------------------------------------------------------------------------
557 // Take reciprocal of a vector on device (impl in .cu file)
558 //------------------------------------------------------------------------------
559 int CeedDeviceReciprocal_Hip(CeedScalar *d_array, CeedSize length);
560 
561 //------------------------------------------------------------------------------
562 // Take reciprocal of a vector
563 //------------------------------------------------------------------------------
564 static int CeedVectorReciprocal_Hip(CeedVector vec) {
565   CeedSize        length;
566   CeedVector_Hip *impl;
567 
568   CeedCallBackend(CeedVectorGetData(vec, &impl));
569   CeedCallBackend(CeedVectorGetLength(vec, &length));
570   // Set value for synced device/host array
571   if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Hip(impl->d_array, length));
572   if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Hip(impl->h_array, length));
573   return CEED_ERROR_SUCCESS;
574 }
575 
576 //------------------------------------------------------------------------------
577 // Compute x = alpha x on the host
578 //------------------------------------------------------------------------------
579 static int CeedHostScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length) {
580   for (CeedSize i = 0; i < length; i++) x_array[i] *= alpha;
581   return CEED_ERROR_SUCCESS;
582 }
583 
584 //------------------------------------------------------------------------------
585 // Compute x = alpha x on device (impl in .cu file)
586 //------------------------------------------------------------------------------
587 int CeedDeviceScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length);
588 
589 //------------------------------------------------------------------------------
590 // Compute x = alpha x
591 //------------------------------------------------------------------------------
592 static int CeedVectorScale_Hip(CeedVector x, CeedScalar alpha) {
593   CeedSize        length;
594   CeedVector_Hip *x_impl;
595 
596   CeedCallBackend(CeedVectorGetData(x, &x_impl));
597   CeedCallBackend(CeedVectorGetLength(x, &length));
598   // Set value for synced device/host array
599   if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Hip(x_impl->d_array, alpha, length));
600   if (x_impl->h_array) CeedCallBackend(CeedHostScale_Hip(x_impl->h_array, alpha, length));
601   return CEED_ERROR_SUCCESS;
602 }
603 
604 //------------------------------------------------------------------------------
605 // Compute y = alpha x + y on the host
606 //------------------------------------------------------------------------------
607 static int CeedHostAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) {
608   for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i];
609   return CEED_ERROR_SUCCESS;
610 }
611 
612 //------------------------------------------------------------------------------
613 // Compute y = alpha x + y on device (impl in .cu file)
614 //------------------------------------------------------------------------------
615 int CeedDeviceAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length);
616 
617 //------------------------------------------------------------------------------
618 // Compute y = alpha x + y
619 //------------------------------------------------------------------------------
620 static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) {
621   CeedSize        length;
622   CeedVector_Hip *y_impl, *x_impl;
623 
624   CeedCallBackend(CeedVectorGetData(y, &y_impl));
625   CeedCallBackend(CeedVectorGetData(x, &x_impl));
626   CeedCallBackend(CeedVectorGetLength(y, &length));
627   // Set value for synced device/host array
628   if (y_impl->d_array) {
629     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
630     CeedCallBackend(CeedDeviceAXPY_Hip(y_impl->d_array, alpha, x_impl->d_array, length));
631   }
632   if (y_impl->h_array) {
633     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
634     CeedCallBackend(CeedHostAXPY_Hip(y_impl->h_array, alpha, x_impl->h_array, length));
635   }
636   return CEED_ERROR_SUCCESS;
637 }
638 
639 //------------------------------------------------------------------------------
640 // Compute y = alpha x + beta y on the host
641 //------------------------------------------------------------------------------
642 static int CeedHostAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length) {
643   for (CeedSize i = 0; i < length; i++) y_array[i] = alpha * x_array[i] + beta * y_array[i];
644   return CEED_ERROR_SUCCESS;
645 }
646 
647 //------------------------------------------------------------------------------
648 // Compute y = alpha x + beta y on device (impl in .cu file)
649 //------------------------------------------------------------------------------
650 int CeedDeviceAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length);
651 
652 //------------------------------------------------------------------------------
653 // Compute y = alpha x + beta y
654 //------------------------------------------------------------------------------
655 static int CeedVectorAXPBY_Hip(CeedVector y, CeedScalar alpha, CeedScalar beta, CeedVector x) {
656   CeedSize        length;
657   CeedVector_Hip *y_impl, *x_impl;
658 
659   CeedCallBackend(CeedVectorGetData(y, &y_impl));
660   CeedCallBackend(CeedVectorGetData(x, &x_impl));
661   CeedCallBackend(CeedVectorGetLength(y, &length));
662   // Set value for synced device/host array
663   if (y_impl->d_array) {
664     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
665     CeedCallBackend(CeedDeviceAXPBY_Hip(y_impl->d_array, alpha, beta, x_impl->d_array, length));
666   }
667   if (y_impl->h_array) {
668     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
669     CeedCallBackend(CeedHostAXPBY_Hip(y_impl->h_array, alpha, beta, x_impl->h_array, length));
670   }
671   return CEED_ERROR_SUCCESS;
672 }
673 
674 //------------------------------------------------------------------------------
675 // Compute the pointwise multiplication w = x .* y on the host
676 //------------------------------------------------------------------------------
677 static int CeedHostPointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) {
678   for (CeedSize i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i];
679   return CEED_ERROR_SUCCESS;
680 }
681 
682 //------------------------------------------------------------------------------
683 // Compute the pointwise multiplication w = x .* y on device (impl in .cu file)
684 //------------------------------------------------------------------------------
685 int CeedDevicePointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length);
686 
687 //------------------------------------------------------------------------------
688 // Compute the pointwise multiplication w = x .* y
689 //------------------------------------------------------------------------------
690 static int CeedVectorPointwiseMult_Hip(CeedVector w, CeedVector x, CeedVector y) {
691   CeedSize        length;
692   CeedVector_Hip *w_impl, *x_impl, *y_impl;
693 
694   CeedCallBackend(CeedVectorGetData(w, &w_impl));
695   CeedCallBackend(CeedVectorGetData(x, &x_impl));
696   CeedCallBackend(CeedVectorGetData(y, &y_impl));
697   CeedCallBackend(CeedVectorGetLength(w, &length));
698 
699   // Set value for synced device/host array
700   if (!w_impl->d_array && !w_impl->h_array) {
701     CeedCallBackend(CeedVectorSetValue(w, 0.0));
702   }
703   if (w_impl->d_array) {
704     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
705     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE));
706     CeedCallBackend(CeedDevicePointwiseMult_Hip(w_impl->d_array, x_impl->d_array, y_impl->d_array, length));
707   }
708   if (w_impl->h_array) {
709     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
710     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST));
711     CeedCallBackend(CeedHostPointwiseMult_Hip(w_impl->h_array, x_impl->h_array, y_impl->h_array, length));
712   }
713   return CEED_ERROR_SUCCESS;
714 }
715 
716 //------------------------------------------------------------------------------
717 // Destroy the vector
718 //------------------------------------------------------------------------------
719 static int CeedVectorDestroy_Hip(const CeedVector vec) {
720   CeedVector_Hip *impl;
721 
722   CeedCallBackend(CeedVectorGetData(vec, &impl));
723   CeedCallHip(CeedVectorReturnCeed(vec), hipFree(impl->d_array_owned));
724   CeedCallBackend(CeedFree(&impl->h_array_owned));
725   CeedCallBackend(CeedFree(&impl));
726   return CEED_ERROR_SUCCESS;
727 }
728 
729 //------------------------------------------------------------------------------
730 // Create a vector of the specified length (does not allocate memory)
731 //------------------------------------------------------------------------------
732 int CeedVectorCreate_Hip(CeedSize n, CeedVector vec) {
733   CeedVector_Hip *impl;
734   Ceed            ceed;
735 
736   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
737   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Hip));
738   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Hip));
739   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Hip));
740   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Hip));
741   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", (int (*)())CeedVectorSetValue_Hip));
742   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Hip));
743   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Hip));
744   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Hip));
745   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Hip));
746   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Norm", CeedVectorNorm_Hip));
747   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Hip));
748   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Scale", (int (*)())CeedVectorScale_Hip));
749   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPY", (int (*)())CeedVectorAXPY_Hip));
750   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPBY", (int (*)())CeedVectorAXPBY_Hip));
751   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Hip));
752   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Hip));
753   CeedCallBackend(CeedCalloc(1, &impl));
754   CeedCallBackend(CeedVectorSetData(vec, impl));
755   return CEED_ERROR_SUCCESS;
756 }
757 
758 //------------------------------------------------------------------------------
759