xref: /libCEED/backends/hip-ref/ceed-hip-ref-vector.c (revision 5aed82e4fa97acf4ba24a7f10a35f5303a6798e0)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <math.h>
11 #include <stdbool.h>
12 #include <string.h>
13 #include <hip/hip_runtime.h>
14 
15 #include "../hip/ceed-hip-common.h"
16 #include "ceed-hip-ref.h"
17 
18 //------------------------------------------------------------------------------
19 // Check if host/device sync is needed
20 //------------------------------------------------------------------------------
21 static inline int CeedVectorNeedSync_Hip(const CeedVector vec, CeedMemType mem_type, bool *need_sync) {
22   CeedVector_Hip *impl;
23   bool            has_valid_array = false;
24 
25   CeedCallBackend(CeedVectorGetData(vec, &impl));
26   CeedCallBackend(CeedVectorHasValidArray(vec, &has_valid_array));
27   switch (mem_type) {
28     case CEED_MEM_HOST:
29       *need_sync = has_valid_array && !impl->h_array;
30       break;
31     case CEED_MEM_DEVICE:
32       *need_sync = has_valid_array && !impl->d_array;
33       break;
34   }
35   return CEED_ERROR_SUCCESS;
36 }
37 
38 //------------------------------------------------------------------------------
39 // Sync host to device
40 //------------------------------------------------------------------------------
41 static inline int CeedVectorSyncH2D_Hip(const CeedVector vec) {
42   Ceed            ceed;
43   CeedSize        length;
44   size_t          bytes;
45   CeedVector_Hip *impl;
46 
47   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
48   CeedCallBackend(CeedVectorGetData(vec, &impl));
49 
50   CeedCheck(impl->h_array, ceed, CEED_ERROR_BACKEND, "No valid host data to sync to device");
51 
52   CeedCallBackend(CeedVectorGetLength(vec, &length));
53   bytes = length * sizeof(CeedScalar);
54   if (impl->d_array_borrowed) {
55     impl->d_array = impl->d_array_borrowed;
56   } else if (impl->d_array_owned) {
57     impl->d_array = impl->d_array_owned;
58   } else {
59     CeedCallHip(ceed, hipMalloc((void **)&impl->d_array_owned, bytes));
60     impl->d_array = impl->d_array_owned;
61   }
62   CeedCallHip(ceed, hipMemcpy(impl->d_array, impl->h_array, bytes, hipMemcpyHostToDevice));
63   return CEED_ERROR_SUCCESS;
64 }
65 
66 //------------------------------------------------------------------------------
67 // Sync device to host
68 //------------------------------------------------------------------------------
69 static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) {
70   Ceed            ceed;
71   CeedSize        length;
72   size_t          bytes;
73   CeedVector_Hip *impl;
74 
75   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
76   CeedCallBackend(CeedVectorGetData(vec, &impl));
77 
78   CeedCheck(impl->d_array, ceed, CEED_ERROR_BACKEND, "No valid device data to sync to host");
79 
80   if (impl->h_array_borrowed) {
81     impl->h_array = impl->h_array_borrowed;
82   } else if (impl->h_array_owned) {
83     impl->h_array = impl->h_array_owned;
84   } else {
85     CeedSize length;
86 
87     CeedCallBackend(CeedVectorGetLength(vec, &length));
88     CeedCallBackend(CeedCalloc(length, &impl->h_array_owned));
89     impl->h_array = impl->h_array_owned;
90   }
91 
92   CeedCallBackend(CeedVectorGetLength(vec, &length));
93   bytes = length * sizeof(CeedScalar);
94   CeedCallHip(ceed, hipMemcpy(impl->h_array, impl->d_array, bytes, hipMemcpyDeviceToHost));
95   return CEED_ERROR_SUCCESS;
96 }
97 
98 //------------------------------------------------------------------------------
99 // Sync arrays
100 //------------------------------------------------------------------------------
101 static int CeedVectorSyncArray_Hip(const CeedVector vec, CeedMemType mem_type) {
102   bool need_sync = false;
103 
104   // Check whether device/host sync is needed
105   CeedCallBackend(CeedVectorNeedSync_Hip(vec, mem_type, &need_sync));
106   if (!need_sync) return CEED_ERROR_SUCCESS;
107 
108   switch (mem_type) {
109     case CEED_MEM_HOST:
110       return CeedVectorSyncD2H_Hip(vec);
111     case CEED_MEM_DEVICE:
112       return CeedVectorSyncH2D_Hip(vec);
113   }
114   return CEED_ERROR_UNSUPPORTED;
115 }
116 
117 //------------------------------------------------------------------------------
118 // Set all pointers as invalid
119 //------------------------------------------------------------------------------
120 static inline int CeedVectorSetAllInvalid_Hip(const CeedVector vec) {
121   CeedVector_Hip *impl;
122 
123   CeedCallBackend(CeedVectorGetData(vec, &impl));
124   impl->h_array = NULL;
125   impl->d_array = NULL;
126   return CEED_ERROR_SUCCESS;
127 }
128 
129 //------------------------------------------------------------------------------
130 // Check if CeedVector has any valid pointer
131 //------------------------------------------------------------------------------
132 static inline int CeedVectorHasValidArray_Hip(const CeedVector vec, bool *has_valid_array) {
133   CeedVector_Hip *impl;
134 
135   CeedCallBackend(CeedVectorGetData(vec, &impl));
136   *has_valid_array = impl->h_array || impl->d_array;
137   return CEED_ERROR_SUCCESS;
138 }
139 
140 //------------------------------------------------------------------------------
141 // Check if has array of given type
142 //------------------------------------------------------------------------------
143 static inline int CeedVectorHasArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_array_of_type) {
144   CeedVector_Hip *impl;
145 
146   CeedCallBackend(CeedVectorGetData(vec, &impl));
147   switch (mem_type) {
148     case CEED_MEM_HOST:
149       *has_array_of_type = impl->h_array_borrowed || impl->h_array_owned;
150       break;
151     case CEED_MEM_DEVICE:
152       *has_array_of_type = impl->d_array_borrowed || impl->d_array_owned;
153       break;
154   }
155   return CEED_ERROR_SUCCESS;
156 }
157 
158 //------------------------------------------------------------------------------
159 // Check if has borrowed array of given type
160 //------------------------------------------------------------------------------
161 static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec, CeedMemType mem_type, bool *has_borrowed_array_of_type) {
162   CeedVector_Hip *impl;
163 
164   CeedCallBackend(CeedVectorGetData(vec, &impl));
165   switch (mem_type) {
166     case CEED_MEM_HOST:
167       *has_borrowed_array_of_type = impl->h_array_borrowed;
168       break;
169     case CEED_MEM_DEVICE:
170       *has_borrowed_array_of_type = impl->d_array_borrowed;
171       break;
172   }
173   return CEED_ERROR_SUCCESS;
174 }
175 
176 //------------------------------------------------------------------------------
177 // Set array from host
178 //------------------------------------------------------------------------------
179 static int CeedVectorSetArrayHost_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
180   CeedSize        length;
181   CeedVector_Hip *impl;
182 
183   CeedCallBackend(CeedVectorGetData(vec, &impl));
184   CeedCallBackend(CeedVectorGetLength(vec, &length));
185 
186   CeedCallBackend(CeedSetHostCeedScalarArray(array, copy_mode, length, (const CeedScalar **)&impl->h_array_owned,
187                                              (const CeedScalar **)&impl->h_array_borrowed, (const CeedScalar **)&impl->h_array));
188   return CEED_ERROR_SUCCESS;
189 }
190 
191 //------------------------------------------------------------------------------
192 // Set array from device
193 //------------------------------------------------------------------------------
194 static int CeedVectorSetArrayDevice_Hip(const CeedVector vec, const CeedCopyMode copy_mode, CeedScalar *array) {
195   CeedSize        length;
196   Ceed            ceed;
197   CeedVector_Hip *impl;
198 
199   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
200   CeedCallBackend(CeedVectorGetData(vec, &impl));
201   CeedCallBackend(CeedVectorGetLength(vec, &length));
202 
203   CeedCallBackend(CeedSetDeviceCeedScalarArray_Hip(ceed, array, copy_mode, length, (const CeedScalar **)&impl->d_array_owned,
204                                                    (const CeedScalar **)&impl->d_array_borrowed, (const CeedScalar **)&impl->d_array));
205   return CEED_ERROR_SUCCESS;
206 }
207 
208 //------------------------------------------------------------------------------
209 // Set the array used by a vector,
210 //   freeing any previously allocated array if applicable
211 //------------------------------------------------------------------------------
212 static int CeedVectorSetArray_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedCopyMode copy_mode, CeedScalar *array) {
213   CeedVector_Hip *impl;
214 
215   CeedCallBackend(CeedVectorGetData(vec, &impl));
216   CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
217   switch (mem_type) {
218     case CEED_MEM_HOST:
219       return CeedVectorSetArrayHost_Hip(vec, copy_mode, array);
220     case CEED_MEM_DEVICE:
221       return CeedVectorSetArrayDevice_Hip(vec, copy_mode, array);
222   }
223   return CEED_ERROR_UNSUPPORTED;
224 }
225 
226 //------------------------------------------------------------------------------
227 // Set host array to value
228 //------------------------------------------------------------------------------
229 static int CeedHostSetValue_Hip(CeedScalar *h_array, CeedSize length, CeedScalar val) {
230   for (CeedSize i = 0; i < length; i++) h_array[i] = val;
231   return CEED_ERROR_SUCCESS;
232 }
233 
234 //------------------------------------------------------------------------------
235 // Set device array to value (impl in .hip file)
236 //------------------------------------------------------------------------------
237 int CeedDeviceSetValue_Hip(CeedScalar *d_array, CeedSize length, CeedScalar val);
238 
239 //------------------------------------------------------------------------------
240 // Set a vector to a value
241 //------------------------------------------------------------------------------
242 static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) {
243   CeedSize        length;
244   CeedVector_Hip *impl;
245 
246   CeedCallBackend(CeedVectorGetData(vec, &impl));
247   CeedCallBackend(CeedVectorGetLength(vec, &length));
248   // Set value for synced device/host array
249   if (!impl->d_array && !impl->h_array) {
250     if (impl->d_array_borrowed) {
251       impl->d_array = impl->d_array_borrowed;
252     } else if (impl->h_array_borrowed) {
253       impl->h_array = impl->h_array_borrowed;
254     } else if (impl->d_array_owned) {
255       impl->d_array = impl->d_array_owned;
256     } else if (impl->h_array_owned) {
257       impl->h_array = impl->h_array_owned;
258     } else {
259       CeedCallBackend(CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL));
260     }
261   }
262   if (impl->d_array) {
263     CeedCallBackend(CeedDeviceSetValue_Hip(impl->d_array, length, val));
264     impl->h_array = NULL;
265   }
266   if (impl->h_array) {
267     CeedCallBackend(CeedHostSetValue_Hip(impl->h_array, length, val));
268     impl->d_array = NULL;
269   }
270   return CEED_ERROR_SUCCESS;
271 }
272 
273 //------------------------------------------------------------------------------
274 // Vector Take Array
275 //------------------------------------------------------------------------------
276 static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type, CeedScalar **array) {
277   CeedVector_Hip *impl;
278 
279   CeedCallBackend(CeedVectorGetData(vec, &impl));
280 
281   // Sync array to requested mem_type
282   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
283 
284   // Update pointer
285   switch (mem_type) {
286     case CEED_MEM_HOST:
287       (*array)               = impl->h_array_borrowed;
288       impl->h_array_borrowed = NULL;
289       impl->h_array          = NULL;
290       break;
291     case CEED_MEM_DEVICE:
292       (*array)               = impl->d_array_borrowed;
293       impl->d_array_borrowed = NULL;
294       impl->d_array          = NULL;
295       break;
296   }
297   return CEED_ERROR_SUCCESS;
298 }
299 
300 //------------------------------------------------------------------------------
301 // Core logic for array syncronization for GetArray.
302 //   If a different memory type is most up to date, this will perform a copy
303 //------------------------------------------------------------------------------
304 static int CeedVectorGetArrayCore_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
305   CeedVector_Hip *impl;
306 
307   CeedCallBackend(CeedVectorGetData(vec, &impl));
308 
309   // Sync array to requested mem_type
310   CeedCallBackend(CeedVectorSyncArray(vec, mem_type));
311 
312   // Update pointer
313   switch (mem_type) {
314     case CEED_MEM_HOST:
315       *array = impl->h_array;
316       break;
317     case CEED_MEM_DEVICE:
318       *array = impl->d_array;
319       break;
320   }
321   return CEED_ERROR_SUCCESS;
322 }
323 
324 //------------------------------------------------------------------------------
325 // Get read-only access to a vector via the specified mem_type
326 //------------------------------------------------------------------------------
327 static int CeedVectorGetArrayRead_Hip(const CeedVector vec, const CeedMemType mem_type, const CeedScalar **array) {
328   return CeedVectorGetArrayCore_Hip(vec, mem_type, (CeedScalar **)array);
329 }
330 
331 //------------------------------------------------------------------------------
332 // Get read/write access to a vector via the specified mem_type
333 //------------------------------------------------------------------------------
334 static int CeedVectorGetArray_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
335   CeedVector_Hip *impl;
336 
337   CeedCallBackend(CeedVectorGetData(vec, &impl));
338   CeedCallBackend(CeedVectorGetArrayCore_Hip(vec, mem_type, array));
339   CeedCallBackend(CeedVectorSetAllInvalid_Hip(vec));
340   switch (mem_type) {
341     case CEED_MEM_HOST:
342       impl->h_array = *array;
343       break;
344     case CEED_MEM_DEVICE:
345       impl->d_array = *array;
346       break;
347   }
348   return CEED_ERROR_SUCCESS;
349 }
350 
351 //------------------------------------------------------------------------------
352 // Get write access to a vector via the specified mem_type
353 //------------------------------------------------------------------------------
354 static int CeedVectorGetArrayWrite_Hip(const CeedVector vec, const CeedMemType mem_type, CeedScalar **array) {
355   bool            has_array_of_type = true;
356   CeedVector_Hip *impl;
357 
358   CeedCallBackend(CeedVectorGetData(vec, &impl));
359   CeedCallBackend(CeedVectorHasArrayOfType_Hip(vec, mem_type, &has_array_of_type));
360   if (!has_array_of_type) {
361     // Allocate if array is not yet allocated
362     CeedCallBackend(CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL));
363   } else {
364     // Select dirty array
365     switch (mem_type) {
366       case CEED_MEM_HOST:
367         if (impl->h_array_borrowed) impl->h_array = impl->h_array_borrowed;
368         else impl->h_array = impl->h_array_owned;
369         break;
370       case CEED_MEM_DEVICE:
371         if (impl->d_array_borrowed) impl->d_array = impl->d_array_borrowed;
372         else impl->d_array = impl->d_array_owned;
373     }
374   }
375   return CeedVectorGetArray_Hip(vec, mem_type, array);
376 }
377 
378 //------------------------------------------------------------------------------
379 // Get the norm of a CeedVector
380 //------------------------------------------------------------------------------
381 static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type, CeedScalar *norm) {
382   Ceed              ceed;
383   CeedSize          length, num_calls;
384   const CeedScalar *d_array;
385   CeedVector_Hip   *impl;
386   hipblasHandle_t   handle;
387 
388   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
389   CeedCallBackend(CeedVectorGetData(vec, &impl));
390   CeedCallBackend(CeedVectorGetLength(vec, &length));
391   CeedCallBackend(CeedGetHipblasHandle_Hip(ceed, &handle));
392 
393   // Is the vector too long to handle with int32? If so, we will divide
394   // it up into "int32-sized" subsections and make repeated BLAS calls.
395   num_calls = length / INT_MAX;
396   if (length % INT_MAX > 0) num_calls += 1;
397 
398   // Compute norm
399   CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array));
400   switch (type) {
401     case CEED_NORM_1: {
402       *norm = 0.0;
403       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
404         float  sub_norm = 0.0;
405         float *d_array_start;
406 
407         for (CeedInt i = 0; i < num_calls; i++) {
408           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
409           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
410           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
411 
412           CeedCallHipblas(ceed, hipblasSasum(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
413           *norm += sub_norm;
414         }
415       } else {
416         double  sub_norm = 0.0;
417         double *d_array_start;
418 
419         for (CeedInt i = 0; i < num_calls; i++) {
420           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
421           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
422           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
423 
424           CeedCallHipblas(ceed, hipblasDasum(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
425           *norm += sub_norm;
426         }
427       }
428       break;
429     }
430     case CEED_NORM_2: {
431       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
432         float  sub_norm = 0.0, norm_sum = 0.0;
433         float *d_array_start;
434 
435         for (CeedInt i = 0; i < num_calls; i++) {
436           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
437           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
438           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
439 
440           CeedCallHipblas(ceed, hipblasSnrm2(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &sub_norm));
441           norm_sum += sub_norm * sub_norm;
442         }
443         *norm = sqrt(norm_sum);
444       } else {
445         double  sub_norm = 0.0, norm_sum = 0.0;
446         double *d_array_start;
447 
448         for (CeedInt i = 0; i < num_calls; i++) {
449           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
450           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
451           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
452 
453           CeedCallHipblas(ceed, hipblasDnrm2(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &sub_norm));
454           norm_sum += sub_norm * sub_norm;
455         }
456         *norm = sqrt(norm_sum);
457       }
458       break;
459     }
460     case CEED_NORM_MAX: {
461       CeedInt index;
462 
463       if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
464         float  sub_max = 0.0, current_max = 0.0;
465         float *d_array_start;
466         for (CeedInt i = 0; i < num_calls; i++) {
467           d_array_start             = (float *)d_array + (CeedSize)(i)*INT_MAX;
468           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
469           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
470 
471           CeedCallHipblas(ceed, hipblasIsamax(handle, (CeedInt)sub_length, (float *)d_array_start, 1, &index));
472           CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
473           if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
474         }
475         *norm = current_max;
476       } else {
477         double  sub_max = 0.0, current_max = 0.0;
478         double *d_array_start;
479 
480         for (CeedInt i = 0; i < num_calls; i++) {
481           d_array_start             = (double *)d_array + (CeedSize)(i)*INT_MAX;
482           CeedSize remaining_length = length - (CeedSize)(i)*INT_MAX;
483           CeedInt  sub_length       = (i == num_calls - 1) ? (CeedInt)(remaining_length) : INT_MAX;
484 
485           CeedCallHipblas(ceed, hipblasIdamax(handle, (CeedInt)sub_length, (double *)d_array_start, 1, &index));
486           CeedCallHip(ceed, hipMemcpy(&sub_max, d_array_start + index - 1, sizeof(CeedScalar), hipMemcpyDeviceToHost));
487           if (fabs(sub_max) > current_max) current_max = fabs(sub_max);
488         }
489         *norm = current_max;
490       }
491       break;
492     }
493   }
494   CeedCallBackend(CeedVectorRestoreArrayRead(vec, &d_array));
495   return CEED_ERROR_SUCCESS;
496 }
497 
498 //------------------------------------------------------------------------------
499 // Take reciprocal of a vector on host
500 //------------------------------------------------------------------------------
501 static int CeedHostReciprocal_Hip(CeedScalar *h_array, CeedSize length) {
502   for (CeedSize i = 0; i < length; i++) {
503     if (fabs(h_array[i]) > CEED_EPSILON) h_array[i] = 1. / h_array[i];
504   }
505   return CEED_ERROR_SUCCESS;
506 }
507 
508 //------------------------------------------------------------------------------
509 // Take reciprocal of a vector on device (impl in .cu file)
510 //------------------------------------------------------------------------------
511 int CeedDeviceReciprocal_Hip(CeedScalar *d_array, CeedSize length);
512 
513 //------------------------------------------------------------------------------
514 // Take reciprocal of a vector
515 //------------------------------------------------------------------------------
516 static int CeedVectorReciprocal_Hip(CeedVector vec) {
517   CeedSize        length;
518   CeedVector_Hip *impl;
519 
520   CeedCallBackend(CeedVectorGetData(vec, &impl));
521   CeedCallBackend(CeedVectorGetLength(vec, &length));
522   // Set value for synced device/host array
523   if (impl->d_array) CeedCallBackend(CeedDeviceReciprocal_Hip(impl->d_array, length));
524   if (impl->h_array) CeedCallBackend(CeedHostReciprocal_Hip(impl->h_array, length));
525   return CEED_ERROR_SUCCESS;
526 }
527 
528 //------------------------------------------------------------------------------
529 // Compute x = alpha x on the host
530 //------------------------------------------------------------------------------
531 static int CeedHostScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length) {
532   for (CeedSize i = 0; i < length; i++) x_array[i] *= alpha;
533   return CEED_ERROR_SUCCESS;
534 }
535 
536 //------------------------------------------------------------------------------
537 // Compute x = alpha x on device (impl in .cu file)
538 //------------------------------------------------------------------------------
539 int CeedDeviceScale_Hip(CeedScalar *x_array, CeedScalar alpha, CeedSize length);
540 
541 //------------------------------------------------------------------------------
542 // Compute x = alpha x
543 //------------------------------------------------------------------------------
544 static int CeedVectorScale_Hip(CeedVector x, CeedScalar alpha) {
545   CeedSize        length;
546   CeedVector_Hip *x_impl;
547 
548   CeedCallBackend(CeedVectorGetData(x, &x_impl));
549   CeedCallBackend(CeedVectorGetLength(x, &length));
550   // Set value for synced device/host array
551   if (x_impl->d_array) CeedCallBackend(CeedDeviceScale_Hip(x_impl->d_array, alpha, length));
552   if (x_impl->h_array) CeedCallBackend(CeedHostScale_Hip(x_impl->h_array, alpha, length));
553   return CEED_ERROR_SUCCESS;
554 }
555 
556 //------------------------------------------------------------------------------
557 // Compute y = alpha x + y on the host
558 //------------------------------------------------------------------------------
559 static int CeedHostAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length) {
560   for (CeedSize i = 0; i < length; i++) y_array[i] += alpha * x_array[i];
561   return CEED_ERROR_SUCCESS;
562 }
563 
564 //------------------------------------------------------------------------------
565 // Compute y = alpha x + y on device (impl in .cu file)
566 //------------------------------------------------------------------------------
567 int CeedDeviceAXPY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar *x_array, CeedSize length);
568 
569 //------------------------------------------------------------------------------
570 // Compute y = alpha x + y
571 //------------------------------------------------------------------------------
572 static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) {
573   CeedSize        length;
574   CeedVector_Hip *y_impl, *x_impl;
575 
576   CeedCallBackend(CeedVectorGetData(y, &y_impl));
577   CeedCallBackend(CeedVectorGetData(x, &x_impl));
578   CeedCallBackend(CeedVectorGetLength(y, &length));
579   // Set value for synced device/host array
580   if (y_impl->d_array) {
581     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
582     CeedCallBackend(CeedDeviceAXPY_Hip(y_impl->d_array, alpha, x_impl->d_array, length));
583   }
584   if (y_impl->h_array) {
585     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
586     CeedCallBackend(CeedHostAXPY_Hip(y_impl->h_array, alpha, x_impl->h_array, length));
587   }
588   return CEED_ERROR_SUCCESS;
589 }
590 
591 //------------------------------------------------------------------------------
592 // Compute y = alpha x + beta y on the host
593 //------------------------------------------------------------------------------
594 static int CeedHostAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length) {
595   for (CeedSize i = 0; i < length; i++) y_array[i] = alpha * x_array[i] + beta * y_array[i];
596   return CEED_ERROR_SUCCESS;
597 }
598 
599 //------------------------------------------------------------------------------
600 // Compute y = alpha x + beta y on device (impl in .cu file)
601 //------------------------------------------------------------------------------
602 int CeedDeviceAXPBY_Hip(CeedScalar *y_array, CeedScalar alpha, CeedScalar beta, CeedScalar *x_array, CeedSize length);
603 
604 //------------------------------------------------------------------------------
605 // Compute y = alpha x + beta y
606 //------------------------------------------------------------------------------
607 static int CeedVectorAXPBY_Hip(CeedVector y, CeedScalar alpha, CeedScalar beta, CeedVector x) {
608   CeedSize        length;
609   CeedVector_Hip *y_impl, *x_impl;
610 
611   CeedCallBackend(CeedVectorGetData(y, &y_impl));
612   CeedCallBackend(CeedVectorGetData(x, &x_impl));
613   CeedCallBackend(CeedVectorGetLength(y, &length));
614   // Set value for synced device/host array
615   if (y_impl->d_array) {
616     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
617     CeedCallBackend(CeedDeviceAXPBY_Hip(y_impl->d_array, alpha, beta, x_impl->d_array, length));
618   }
619   if (y_impl->h_array) {
620     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
621     CeedCallBackend(CeedHostAXPBY_Hip(y_impl->h_array, alpha, beta, x_impl->h_array, length));
622   }
623   return CEED_ERROR_SUCCESS;
624 }
625 
626 //------------------------------------------------------------------------------
627 // Compute the pointwise multiplication w = x .* y on the host
628 //------------------------------------------------------------------------------
629 static int CeedHostPointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length) {
630   for (CeedSize i = 0; i < length; i++) w_array[i] = x_array[i] * y_array[i];
631   return CEED_ERROR_SUCCESS;
632 }
633 
634 //------------------------------------------------------------------------------
635 // Compute the pointwise multiplication w = x .* y on device (impl in .cu file)
636 //------------------------------------------------------------------------------
637 int CeedDevicePointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array, CeedScalar *y_array, CeedSize length);
638 
639 //------------------------------------------------------------------------------
640 // Compute the pointwise multiplication w = x .* y
641 //------------------------------------------------------------------------------
642 static int CeedVectorPointwiseMult_Hip(CeedVector w, CeedVector x, CeedVector y) {
643   CeedSize        length;
644   CeedVector_Hip *w_impl, *x_impl, *y_impl;
645 
646   CeedCallBackend(CeedVectorGetData(w, &w_impl));
647   CeedCallBackend(CeedVectorGetData(x, &x_impl));
648   CeedCallBackend(CeedVectorGetData(y, &y_impl));
649   CeedCallBackend(CeedVectorGetLength(w, &length));
650 
651   // Set value for synced device/host array
652   if (!w_impl->d_array && !w_impl->h_array) {
653     CeedCallBackend(CeedVectorSetValue(w, 0.0));
654   }
655   if (w_impl->d_array) {
656     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_DEVICE));
657     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_DEVICE));
658     CeedCallBackend(CeedDevicePointwiseMult_Hip(w_impl->d_array, x_impl->d_array, y_impl->d_array, length));
659   }
660   if (w_impl->h_array) {
661     CeedCallBackend(CeedVectorSyncArray(x, CEED_MEM_HOST));
662     CeedCallBackend(CeedVectorSyncArray(y, CEED_MEM_HOST));
663     CeedCallBackend(CeedHostPointwiseMult_Hip(w_impl->h_array, x_impl->h_array, y_impl->h_array, length));
664   }
665   return CEED_ERROR_SUCCESS;
666 }
667 
668 //------------------------------------------------------------------------------
669 // Destroy the vector
670 //------------------------------------------------------------------------------
671 static int CeedVectorDestroy_Hip(const CeedVector vec) {
672   CeedVector_Hip *impl;
673 
674   CeedCallBackend(CeedVectorGetData(vec, &impl));
675   CeedCallHip(CeedVectorReturnCeed(vec), hipFree(impl->d_array_owned));
676   CeedCallBackend(CeedFree(&impl->h_array_owned));
677   CeedCallBackend(CeedFree(&impl));
678   return CEED_ERROR_SUCCESS;
679 }
680 
681 //------------------------------------------------------------------------------
682 // Create a vector of the specified length (does not allocate memory)
683 //------------------------------------------------------------------------------
684 int CeedVectorCreate_Hip(CeedSize n, CeedVector vec) {
685   CeedVector_Hip *impl;
686   Ceed            ceed;
687 
688   CeedCallBackend(CeedVectorGetCeed(vec, &ceed));
689   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray", CeedVectorHasValidArray_Hip));
690   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType", CeedVectorHasBorrowedArrayOfType_Hip));
691   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetArray", CeedVectorSetArray_Hip));
692   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray", CeedVectorTakeArray_Hip));
693   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SetValue", (int (*)())CeedVectorSetValue_Hip));
694   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray", CeedVectorSyncArray_Hip));
695   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArray", CeedVectorGetArray_Hip));
696   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead", CeedVectorGetArrayRead_Hip));
697   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite", CeedVectorGetArrayWrite_Hip));
698   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Norm", CeedVectorNorm_Hip));
699   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal", CeedVectorReciprocal_Hip));
700   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Scale", (int (*)())CeedVectorScale_Hip));
701   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPY", (int (*)())CeedVectorAXPY_Hip));
702   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "AXPBY", (int (*)())CeedVectorAXPBY_Hip));
703   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult", CeedVectorPointwiseMult_Hip));
704   CeedCallBackend(CeedSetBackendFunction(ceed, "Vector", vec, "Destroy", CeedVectorDestroy_Hip));
705   CeedCallBackend(CeedCalloc(1, &impl));
706   CeedCallBackend(CeedVectorSetData(vec, impl));
707   return CEED_ERROR_SUCCESS;
708 }
709 
710 //------------------------------------------------------------------------------
711