xref: /libCEED/backends/hip-ref/ceed-hip-ref-vector.c (revision ba6664ae303f5b2ef46b3df96973d9bdc665107c)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/ceed.h>
9 #include <ceed/backend.h>
10 #include <hip/hip_runtime.h>
11 #include <hipblas.h>
12 #include <math.h>
13 #include <string.h>
14 #include "ceed-hip-ref.h"
15 
16 
17 //------------------------------------------------------------------------------
18 // Check if host/device sync is needed
19 //------------------------------------------------------------------------------
20 static inline int CeedVectorNeedSync_Hip(const CeedVector vec,
21     CeedMemType mem_type, bool *need_sync) {
22   int ierr;
23   CeedVector_Hip *impl;
24   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
25 
26   bool has_valid_array = false;
27   ierr = CeedVectorHasValidArray(vec, &has_valid_array); CeedChkBackend(ierr);
28   switch (mem_type) {
29   case CEED_MEM_HOST:
30     *need_sync = has_valid_array && !impl->h_array;
31     break;
32   case CEED_MEM_DEVICE:
33     *need_sync = has_valid_array && !impl->d_array;
34     break;
35   }
36 
37   return CEED_ERROR_SUCCESS;
38 }
39 
40 //------------------------------------------------------------------------------
41 // Sync host to device
42 //------------------------------------------------------------------------------
43 static inline int CeedVectorSyncH2D_Hip(const CeedVector vec) {
44   int ierr;
45   Ceed ceed;
46   ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
47   CeedVector_Hip *impl;
48   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
49 
50   CeedSize length;
51   ierr = CeedVectorGetLength(vec, &length); CeedChkBackend(ierr);
52   size_t bytes = length * sizeof(CeedScalar);
53 
54   if (!impl->h_array)
55     // LCOV_EXCL_START
56     return CeedError(ceed, CEED_ERROR_BACKEND,
57                      "No valid host data to sync to device");
58   // LCOV_EXCL_STOP
59 
60   if (impl->d_array_borrowed) {
61     impl->d_array = impl->d_array_borrowed;
62   } else if (impl->d_array_owned) {
63     impl->d_array = impl->d_array_owned;
64   } else {
65     ierr = hipMalloc((void **)&impl->d_array_owned, bytes);
66     CeedChk_Hip(ceed, ierr);
67     impl->d_array = impl->d_array_owned;
68   }
69 
70   ierr = hipMemcpy(impl->d_array, impl->h_array, bytes,
71                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
72 
73   return CEED_ERROR_SUCCESS;
74 }
75 
76 //------------------------------------------------------------------------------
77 // Sync device to host
78 //------------------------------------------------------------------------------
79 static inline int CeedVectorSyncD2H_Hip(const CeedVector vec) {
80   int ierr;
81   Ceed ceed;
82   ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
83   CeedVector_Hip *impl;
84   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
85 
86   if (!impl->d_array)
87     // LCOV_EXCL_START
88     return CeedError(ceed, CEED_ERROR_BACKEND,
89                      "No valid device data to sync to host");
90   // LCOV_EXCL_STOP
91 
92   if (impl->h_array_borrowed) {
93     impl->h_array = impl->h_array_borrowed;
94   } else if (impl->h_array_owned) {
95     impl->h_array = impl->h_array_owned;
96   } else {
97     CeedSize length;
98     ierr = CeedVectorGetLength(vec, &length); CeedChkBackend(ierr);
99     ierr = CeedCalloc(length, &impl->h_array_owned); CeedChkBackend(ierr);
100     impl->h_array = impl->h_array_owned;
101   }
102 
103   CeedSize length;
104   ierr = CeedVectorGetLength(vec, &length); CeedChkBackend(ierr);
105   size_t bytes = length * sizeof(CeedScalar);
106   ierr = hipMemcpy(impl->h_array, impl->d_array, bytes,
107                    hipMemcpyDeviceToHost); CeedChk_Hip(ceed, ierr);
108 
109   return CEED_ERROR_SUCCESS;
110 }
111 
112 //------------------------------------------------------------------------------
113 // Sync arrays
114 //------------------------------------------------------------------------------
115 static int CeedVectorSyncArray_Hip(const CeedVector vec,
116                                    CeedMemType mem_type) {
117   int ierr;
118   // Check whether device/host sync is needed
119   bool need_sync = false;
120   ierr = CeedVectorNeedSync_Hip(vec, mem_type, &need_sync);
121   CeedChkBackend(ierr);
122   if (!need_sync)
123     return CEED_ERROR_SUCCESS;
124 
125   switch (mem_type) {
126   case CEED_MEM_HOST: return CeedVectorSyncD2H_Hip(vec);
127   case CEED_MEM_DEVICE: return CeedVectorSyncH2D_Hip(vec);
128   }
129   return CEED_ERROR_UNSUPPORTED;
130 }
131 
132 //------------------------------------------------------------------------------
133 // Set all pointers as invalid
134 //------------------------------------------------------------------------------
135 static inline int CeedVectorSetAllInvalid_Hip(const CeedVector vec) {
136   int ierr;
137   CeedVector_Hip *impl;
138   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
139 
140   impl->h_array = NULL;
141   impl->d_array = NULL;
142 
143   return CEED_ERROR_SUCCESS;
144 }
145 
146 //------------------------------------------------------------------------------
147 // Check if CeedVector has any valid pointers
148 //------------------------------------------------------------------------------
149 static inline int CeedVectorHasValidArray_Hip(const CeedVector vec,
150     bool *has_valid_array) {
151   int ierr;
152   CeedVector_Hip *impl;
153   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
154 
155   *has_valid_array = !!impl->h_array || !!impl->d_array;
156 
157   return CEED_ERROR_SUCCESS;
158 }
159 
160 //------------------------------------------------------------------------------
161 // Check if has any array of given type
162 //------------------------------------------------------------------------------
163 static inline int CeedVectorHasArrayOfType_Hip(const CeedVector vec,
164     CeedMemType mem_type, bool *has_array_of_type) {
165   int ierr;
166   CeedVector_Hip *impl;
167   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
168 
169   switch (mem_type) {
170   case CEED_MEM_HOST:
171     *has_array_of_type = !!impl->h_array_borrowed || !!impl->h_array_owned;
172     break;
173   case CEED_MEM_DEVICE:
174     *has_array_of_type = !!impl->d_array_borrowed || !!impl->d_array_owned;
175     break;
176   }
177 
178   return CEED_ERROR_SUCCESS;
179 }
180 
181 //------------------------------------------------------------------------------
182 // Check if has borrowed array of given type
183 //------------------------------------------------------------------------------
184 static inline int CeedVectorHasBorrowedArrayOfType_Hip(const CeedVector vec,
185     CeedMemType mem_type, bool *has_borrowed_array_of_type) {
186   int ierr;
187   CeedVector_Hip *impl;
188   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
189 
190   switch (mem_type) {
191   case CEED_MEM_HOST:
192     *has_borrowed_array_of_type = !!impl->h_array_borrowed;
193     break;
194   case CEED_MEM_DEVICE:
195     *has_borrowed_array_of_type = !!impl->d_array_borrowed;
196     break;
197   }
198 
199   return CEED_ERROR_SUCCESS;
200 }
201 
202 //------------------------------------------------------------------------------
203 // Set array from host
204 //------------------------------------------------------------------------------
205 static int CeedVectorSetArrayHost_Hip(const CeedVector vec,
206                                       const CeedCopyMode copy_mode, CeedScalar *array) {
207   int ierr;
208   CeedVector_Hip *impl;
209   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
210 
211   switch (copy_mode) {
212   case CEED_COPY_VALUES: {
213     CeedSize length;
214     if (!impl->h_array_owned) {
215       ierr = CeedVectorGetLength(vec, &length); CeedChkBackend(ierr);
216       ierr = CeedMalloc(length, &impl->h_array_owned); CeedChkBackend(ierr);
217     }
218     impl->h_array_borrowed = NULL;
219     impl->h_array = impl->h_array_owned;
220     if (array) {
221       CeedSize length;
222       ierr = CeedVectorGetLength(vec, &length); CeedChkBackend(ierr);
223       size_t bytes = length * sizeof(CeedScalar);
224       memcpy(impl->h_array, array, bytes);
225     }
226   } break;
227   case CEED_OWN_POINTER:
228     ierr = CeedFree(&impl->h_array_owned); CeedChkBackend(ierr);
229     impl->h_array_owned = array;
230     impl->h_array_borrowed = NULL;
231     impl->h_array = array;
232     break;
233   case CEED_USE_POINTER:
234     ierr = CeedFree(&impl->h_array_owned); CeedChkBackend(ierr);
235     impl->h_array_borrowed = array;
236     impl->h_array = array;
237     break;
238   }
239 
240   return CEED_ERROR_SUCCESS;
241 }
242 
243 //------------------------------------------------------------------------------
244 // Set array from device
245 //------------------------------------------------------------------------------
246 static int CeedVectorSetArrayDevice_Hip(const CeedVector vec,
247                                         const CeedCopyMode copy_mode, CeedScalar *array) {
248   int ierr;
249   Ceed ceed;
250   ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
251   CeedVector_Hip *impl;
252   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
253 
254   switch (copy_mode) {
255   case CEED_COPY_VALUES: {
256     CeedSize length;
257     ierr = CeedVectorGetLength(vec, &length); CeedChkBackend(ierr);
258     size_t bytes = length * sizeof(CeedScalar);
259     if (!impl->d_array_owned) {
260       ierr = hipMalloc((void **)&impl->d_array_owned, bytes);
261       CeedChk_Hip(ceed, ierr);
262     }
263     impl->d_array_borrowed = NULL;
264     impl->d_array = impl->d_array_owned;
265     if (array) {
266       ierr = hipMemcpy(impl->d_array, array, bytes,
267                        hipMemcpyDeviceToDevice); CeedChk_Hip(ceed, ierr);
268     }
269   } break;
270   case CEED_OWN_POINTER:
271     ierr = hipFree(impl->d_array_owned); CeedChk_Hip(ceed, ierr);
272     impl->d_array_owned = array;
273     impl->d_array_borrowed = NULL;
274     impl->d_array = array;
275     break;
276   case CEED_USE_POINTER:
277     ierr = hipFree(impl->d_array_owned); CeedChk_Hip(ceed, ierr);
278     impl->d_array_owned = NULL;
279     impl->d_array_borrowed = array;
280     impl->d_array = array;
281     break;
282   }
283 
284   return CEED_ERROR_SUCCESS;
285 }
286 
287 //------------------------------------------------------------------------------
288 // Set the array used by a vector,
289 //   freeing any previously allocated array if applicable
290 //------------------------------------------------------------------------------
291 static int CeedVectorSetArray_Hip(const CeedVector vec,
292                                   const CeedMemType mem_type,
293                                   const CeedCopyMode copy_mode, CeedScalar *array) {
294   int ierr;
295   Ceed ceed;
296   ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
297   CeedVector_Hip *impl;
298   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
299 
300   ierr = CeedVectorSetAllInvalid_Hip(vec); CeedChkBackend(ierr);
301   switch (mem_type) {
302   case CEED_MEM_HOST:
303     return CeedVectorSetArrayHost_Hip(vec, copy_mode, array);
304   case CEED_MEM_DEVICE:
305     return CeedVectorSetArrayDevice_Hip(vec, copy_mode, array);
306   }
307 
308   return CEED_ERROR_UNSUPPORTED;
309 }
310 
311 //------------------------------------------------------------------------------
312 // Set host array to value
313 //------------------------------------------------------------------------------
314 static int CeedHostSetValue_Hip(CeedScalar *h_array, CeedInt length,
315                                 CeedScalar val) {
316   for (int i = 0; i < length; i++)
317     h_array[i] = val;
318   return CEED_ERROR_SUCCESS;
319 }
320 
//------------------------------------------------------------------------------
// Set device array to value (impl in .hip file)
//   Declaration only; the definition is not part of this file.
//   Called by CeedVectorSetValue_Hip with the active device pointer, the
//   vector length, and the fill value.
//   NOTE(review): the caller passes a CeedSize length into this CeedInt
//   parameter - confirm no narrowing can occur for long vectors.
//------------------------------------------------------------------------------
int CeedDeviceSetValue_Hip(CeedScalar *d_array, CeedInt length, CeedScalar val);
325 
326 //------------------------------------------------------------------------------
327 // Set a vector to a value,
328 //------------------------------------------------------------------------------
329 static int CeedVectorSetValue_Hip(CeedVector vec, CeedScalar val) {
330   int ierr;
331   Ceed ceed;
332   ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
333   CeedVector_Hip *impl;
334   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
335   CeedSize length;
336   ierr = CeedVectorGetLength(vec, &length); CeedChkBackend(ierr);
337 
338   // Set value for synced device/host array
339   if (!impl->d_array && !impl->h_array) {
340     if (impl->d_array_borrowed) {
341       impl->d_array = impl->d_array_borrowed;
342     } else if (impl->h_array_borrowed) {
343       impl->h_array = impl->h_array_borrowed;
344     } else if (impl->d_array_owned) {
345       impl->d_array = impl->d_array_owned;
346     } else if (impl->h_array_owned) {
347       impl->h_array = impl->h_array_owned;
348     } else {
349       ierr = CeedVectorSetArray(vec, CEED_MEM_DEVICE, CEED_COPY_VALUES, NULL);
350       CeedChkBackend(ierr);
351     }
352   }
353   if (impl->d_array) {
354     ierr = CeedDeviceSetValue_Hip(impl->d_array, length, val); CeedChkBackend(ierr);
355   }
356   if (impl->h_array) {
357     ierr = CeedHostSetValue_Hip(impl->h_array, length, val); CeedChkBackend(ierr);
358   }
359 
360   return CEED_ERROR_SUCCESS;
361 }
362 
363 //------------------------------------------------------------------------------
364 // Vector Take Array
365 //------------------------------------------------------------------------------
366 static int CeedVectorTakeArray_Hip(CeedVector vec, CeedMemType mem_type,
367                                    CeedScalar **array) {
368   int ierr;
369   Ceed ceed;
370   ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
371   CeedVector_Hip *impl;
372   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
373 
374   // Sync array to requested mem_type
375   ierr = CeedVectorSyncArray(vec, mem_type); CeedChkBackend(ierr);
376 
377   // Update pointer
378   switch (mem_type) {
379   case CEED_MEM_HOST:
380     (*array) = impl->h_array_borrowed;
381     impl->h_array_borrowed = NULL;
382     impl->h_array = NULL;
383     break;
384   case CEED_MEM_DEVICE:
385     (*array) = impl->d_array_borrowed;
386     impl->d_array_borrowed = NULL;
387     impl->d_array = NULL;
388     break;
389   }
390 
391   return CEED_ERROR_SUCCESS;
392 }
393 
394 //------------------------------------------------------------------------------
395 // Core logic for array syncronization for GetArray.
396 //   If a different memory type is most up to date, this will perform a copy
397 //------------------------------------------------------------------------------
398 static int CeedVectorGetArrayCore_Hip(const CeedVector vec,
399                                       const CeedMemType mem_type, CeedScalar **array) {
400   int ierr;
401   Ceed ceed;
402   ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
403   CeedVector_Hip *impl;
404   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
405 
406   // Sync array to requested mem_type
407   ierr = CeedVectorSyncArray(vec, mem_type); CeedChkBackend(ierr);
408 
409   // Update pointer
410   switch (mem_type) {
411   case CEED_MEM_HOST:
412     *array = impl->h_array;
413     break;
414   case CEED_MEM_DEVICE:
415     *array = impl->d_array;
416     break;
417   }
418 
419   return CEED_ERROR_SUCCESS;
420 }
421 
422 //------------------------------------------------------------------------------
423 // Get read-only access to a vector via the specified mem_type
424 //------------------------------------------------------------------------------
425 static int CeedVectorGetArrayRead_Hip(const CeedVector vec,
426                                       const CeedMemType mem_type, const CeedScalar **array) {
427   return CeedVectorGetArrayCore_Hip(vec, mem_type, (CeedScalar **)array);
428 }
429 
430 //------------------------------------------------------------------------------
431 // Get read/write access to a vector via the specified mem_type
432 //------------------------------------------------------------------------------
433 static int CeedVectorGetArray_Hip(const CeedVector vec,
434                                   const CeedMemType mem_type,
435                                   CeedScalar **array) {
436   int ierr;
437   CeedVector_Hip *impl;
438   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
439 
440   ierr = CeedVectorGetArrayCore_Hip(vec, mem_type, array); CeedChkBackend(ierr);
441 
442   ierr = CeedVectorSetAllInvalid_Hip(vec); CeedChkBackend(ierr);
443   switch (mem_type) {
444   case CEED_MEM_HOST:
445     impl->h_array = *array;
446     break;
447   case CEED_MEM_DEVICE:
448     impl->d_array = *array;
449     break;
450   }
451 
452   return CEED_ERROR_SUCCESS;
453 }
454 
455 //------------------------------------------------------------------------------
456 // Get write access to a vector via the specified mem_type
457 //------------------------------------------------------------------------------
458 static int CeedVectorGetArrayWrite_Hip(const CeedVector vec,
459                                        const CeedMemType mem_type, CeedScalar **array) {
460   int ierr;
461   CeedVector_Hip *impl;
462   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
463 
464   bool has_array_of_type = true;
465   ierr = CeedVectorHasArrayOfType_Hip(vec, mem_type, &has_array_of_type);
466   CeedChkBackend(ierr);
467   if (!has_array_of_type) {
468     // Allocate if array is not yet allocated
469     ierr = CeedVectorSetArray(vec, mem_type, CEED_COPY_VALUES, NULL);
470     CeedChkBackend(ierr);
471   } else {
472     // Select dirty array
473     switch (mem_type) {
474     case CEED_MEM_HOST:
475       if (impl->h_array_borrowed)
476         impl->h_array = impl->h_array_borrowed;
477       else
478         impl->h_array = impl->h_array_owned;
479       break;
480     case CEED_MEM_DEVICE:
481       if (impl->d_array_borrowed)
482         impl->d_array = impl->d_array_borrowed;
483       else
484         impl->d_array = impl->d_array_owned;
485     }
486   }
487 
488   return CeedVectorGetArray_Hip(vec, mem_type, array);
489 }
490 
491 //------------------------------------------------------------------------------
492 // Get the norm of a CeedVector
493 //------------------------------------------------------------------------------
494 static int CeedVectorNorm_Hip(CeedVector vec, CeedNormType type,
495                               CeedScalar *norm) {
496   int ierr;
497   Ceed ceed;
498   ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
499   CeedVector_Hip *impl;
500   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
501   CeedSize length;
502   ierr = CeedVectorGetLength(vec, &length); CeedChkBackend(ierr);
503   hipblasHandle_t handle;
504   ierr = CeedHipGetHipblasHandle(ceed, &handle); CeedChkBackend(ierr);
505 
506   // Compute norm
507   const CeedScalar *d_array;
508   ierr = CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, &d_array);
509   CeedChkBackend(ierr);
510   switch (type) {
511   case CEED_NORM_1: {
512     if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
513       ierr = hipblasSasum(handle, length, (float *) d_array, 1, (float *) norm);
514     } else {
515       ierr = hipblasDasum(handle, length, (double *) d_array, 1, (double *) norm);
516     }
517     CeedChk_Hipblas(ceed, ierr);
518     break;
519   }
520   case CEED_NORM_2: {
521     if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
522       ierr = hipblasSnrm2(handle, length, (float *) d_array, 1, (float *) norm);
523     } else {
524       ierr = hipblasDnrm2(handle, length, (double *) d_array, 1, (double *) norm);
525     }
526     CeedChk_Hipblas(ceed, ierr);
527     break;
528   }
529   case CEED_NORM_MAX: {
530     CeedInt indx;
531     if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
532       ierr = hipblasIsamax(handle, length, (float *) d_array, 1, &indx);
533     } else {
534       ierr = hipblasIdamax(handle, length, (double *) d_array, 1, &indx);
535     }
536     CeedChk_Hipblas(ceed, ierr);
537     CeedScalar normNoAbs;
538     ierr = hipMemcpy(&normNoAbs, impl->d_array+indx-1, sizeof(CeedScalar),
539                      hipMemcpyDeviceToHost); CeedChk_Hip(ceed, ierr);
540     *norm = fabs(normNoAbs);
541     break;
542   }
543   }
544   ierr = CeedVectorRestoreArrayRead(vec, &d_array); CeedChkBackend(ierr);
545 
546   return CEED_ERROR_SUCCESS;
547 }
548 
549 //------------------------------------------------------------------------------
550 // Take reciprocal of a vector on host
551 //------------------------------------------------------------------------------
552 static int CeedHostReciprocal_Hip(CeedScalar *h_array, CeedInt length) {
553   for (int i = 0; i < length; i++)
554     if (fabs(h_array[i]) > CEED_EPSILON)
555       h_array[i] = 1./h_array[i];
556   return CEED_ERROR_SUCCESS;
557 }
558 
//------------------------------------------------------------------------------
// Take reciprocal of a vector on device
//   Declaration only; the definition is not part of this file (presumably in
//   the backend's HIP kernel source - confirm).
//   Called by CeedVectorReciprocal_Hip with the active device pointer and the
//   vector length.
//   NOTE(review): the caller passes a CeedSize length into this CeedInt
//   parameter - confirm no narrowing can occur for long vectors.
//------------------------------------------------------------------------------
int CeedDeviceReciprocal_Hip(CeedScalar *d_array, CeedInt length);
563 
564 //------------------------------------------------------------------------------
565 // Take reciprocal of a vector
566 //------------------------------------------------------------------------------
567 static int CeedVectorReciprocal_Hip(CeedVector vec) {
568   int ierr;
569   Ceed ceed;
570   ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
571   CeedVector_Hip *impl;
572   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
573   CeedSize length;
574   ierr = CeedVectorGetLength(vec, &length); CeedChkBackend(ierr);
575 
576   // Set value for synced device/host array
577   if (impl->d_array) {
578     ierr = CeedDeviceReciprocal_Hip(impl->d_array, length); CeedChkBackend(ierr);
579   }
580   if (impl->h_array) {
581     ierr = CeedHostReciprocal_Hip(impl->h_array, length); CeedChkBackend(ierr);
582   }
583 
584   return CEED_ERROR_SUCCESS;
585 }
586 
587 //------------------------------------------------------------------------------
588 // Compute x = alpha x on the host
589 //------------------------------------------------------------------------------
590 static int CeedHostScale_Hip(CeedScalar *x_array, CeedScalar alpha,
591                              CeedInt length) {
592   for (int i = 0; i < length; i++)
593     x_array[i] *= alpha;
594   return CEED_ERROR_SUCCESS;
595 }
596 
//------------------------------------------------------------------------------
// Compute x = alpha x on device
//   Declaration only; the definition is not part of this file (presumably in
//   the backend's HIP kernel source - confirm).
//   Called by CeedVectorScale_Hip with the active device pointer, the scale
//   factor, and the vector length.
//   NOTE(review): the caller passes a CeedSize length into this CeedInt
//   parameter - confirm no narrowing can occur for long vectors.
//------------------------------------------------------------------------------
int CeedDeviceScale_Hip(CeedScalar *x_array, CeedScalar alpha,
                        CeedInt length);
602 
603 //------------------------------------------------------------------------------
604 // Compute x = alpha x
605 //------------------------------------------------------------------------------
606 static int CeedVectorScale_Hip(CeedVector x, CeedScalar alpha) {
607   int ierr;
608   Ceed ceed;
609   ierr = CeedVectorGetCeed(x, &ceed); CeedChkBackend(ierr);
610   CeedVector_Hip *x_impl;
611   ierr = CeedVectorGetData(x, &x_impl); CeedChkBackend(ierr);
612   CeedSize length;
613   ierr = CeedVectorGetLength(x, &length); CeedChkBackend(ierr);
614 
615   // Set value for synced device/host array
616   if (x_impl->d_array) {
617     ierr = CeedDeviceScale_Hip(x_impl->d_array, alpha, length);
618     CeedChkBackend(ierr);
619   }
620   if (x_impl->h_array) {
621     ierr = CeedHostScale_Hip(x_impl->h_array, alpha, length); CeedChkBackend(ierr);
622   }
623 
624   return CEED_ERROR_SUCCESS;
625 }
626 
627 //------------------------------------------------------------------------------
628 // Compute y = alpha x + y on the host
629 //------------------------------------------------------------------------------
630 static int CeedHostAXPY_Hip(CeedScalar *y_array, CeedScalar alpha,
631                             CeedScalar *x_array, CeedInt length) {
632   for (int i = 0; i < length; i++)
633     y_array[i] += alpha * x_array[i];
634   return CEED_ERROR_SUCCESS;
635 }
636 
//------------------------------------------------------------------------------
// Compute y = alpha x + y on device
//   Declaration only; the definition is not part of this file (presumably in
//   the backend's HIP kernel source - confirm).
//   Called by CeedVectorAXPY_Hip with the active device pointers of y and x,
//   the scale factor, and the vector length.
//   NOTE(review): the caller passes a CeedSize length into this CeedInt
//   parameter - confirm no narrowing can occur for long vectors.
//------------------------------------------------------------------------------
int CeedDeviceAXPY_Hip(CeedScalar *y_array, CeedScalar alpha,
                       CeedScalar *x_array, CeedInt length);
642 
643 //------------------------------------------------------------------------------
644 // Compute y = alpha x + y
645 //------------------------------------------------------------------------------
646 static int CeedVectorAXPY_Hip(CeedVector y, CeedScalar alpha, CeedVector x) {
647   int ierr;
648   Ceed ceed;
649   ierr = CeedVectorGetCeed(y, &ceed); CeedChkBackend(ierr);
650   CeedVector_Hip *y_impl, *x_impl;
651   ierr = CeedVectorGetData(y, &y_impl); CeedChkBackend(ierr);
652   ierr = CeedVectorGetData(x, &x_impl); CeedChkBackend(ierr);
653   CeedSize length;
654   ierr = CeedVectorGetLength(y, &length); CeedChkBackend(ierr);
655 
656   // Set value for synced device/host array
657   if (y_impl->d_array) {
658     ierr = CeedVectorSyncArray(x, CEED_MEM_DEVICE); CeedChkBackend(ierr);
659     ierr = CeedDeviceAXPY_Hip(y_impl->d_array, alpha, x_impl->d_array, length);
660     CeedChkBackend(ierr);
661   }
662   if (y_impl->h_array) {
663     ierr = CeedVectorSyncArray(x, CEED_MEM_HOST); CeedChkBackend(ierr);
664     ierr = CeedHostAXPY_Hip(y_impl->h_array, alpha, x_impl->h_array, length);
665     CeedChkBackend(ierr);
666   }
667 
668   return CEED_ERROR_SUCCESS;
669 }
670 
671 //------------------------------------------------------------------------------
672 // Compute the pointwise multiplication w = x .* y on the host
673 //------------------------------------------------------------------------------
674 static int CeedHostPointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array,
675                                      CeedScalar *y_array, CeedInt length) {
676   for (int i = 0; i < length; i++)
677     w_array[i] = x_array[i] * y_array[i];
678   return CEED_ERROR_SUCCESS;
679 }
680 
//------------------------------------------------------------------------------
// Compute the pointwise multiplication w = x .* y on device
//   Declaration only; the definition is not part of this file (presumably in
//   the backend's HIP kernel source - confirm).
//   Called by CeedVectorPointwiseMult_Hip with the active device pointers of
//   w, x and y, and the vector length.
//   NOTE(review): the caller passes a CeedSize length into this CeedInt
//   parameter - confirm no narrowing can occur for long vectors.
//------------------------------------------------------------------------------
int CeedDevicePointwiseMult_Hip(CeedScalar *w_array, CeedScalar *x_array,
                                CeedScalar *y_array, CeedInt length);
686 
687 //------------------------------------------------------------------------------
688 // Compute the pointwise multiplication w = x .* y
689 //------------------------------------------------------------------------------
690 static int CeedVectorPointwiseMult_Hip(CeedVector w, CeedVector x,
691                                        CeedVector y) {
692   int ierr;
693   Ceed ceed;
694   ierr = CeedVectorGetCeed(w, &ceed); CeedChkBackend(ierr);
695   CeedVector_Hip *w_impl, *x_impl, *y_impl;
696   ierr = CeedVectorGetData(w, &w_impl); CeedChkBackend(ierr);
697   ierr = CeedVectorGetData(x, &x_impl); CeedChkBackend(ierr);
698   ierr = CeedVectorGetData(y, &y_impl); CeedChkBackend(ierr);
699   CeedSize length;
700   ierr = CeedVectorGetLength(w, &length); CeedChkBackend(ierr);
701 
702   // Set value for synced device/host array
703   if (!w_impl->d_array && !w_impl->h_array) {
704     ierr = CeedVectorSetValue(w, 0.0); CeedChkBackend(ierr);
705   }
706   if (w_impl->d_array) {
707     ierr = CeedVectorSyncArray(x, CEED_MEM_DEVICE); CeedChkBackend(ierr);
708     ierr = CeedVectorSyncArray(y, CEED_MEM_DEVICE); CeedChkBackend(ierr);
709     ierr = CeedDevicePointwiseMult_Hip(w_impl->d_array, x_impl->d_array,
710                                        y_impl->d_array, length);
711     CeedChkBackend(ierr);
712   }
713   if (w_impl->h_array) {
714     ierr = CeedVectorSyncArray(x, CEED_MEM_HOST); CeedChkBackend(ierr);
715     ierr = CeedVectorSyncArray(y, CEED_MEM_HOST); CeedChkBackend(ierr);
716     ierr = CeedHostPointwiseMult_Hip(w_impl->h_array, x_impl->h_array,
717                                      y_impl->h_array, length);
718     CeedChkBackend(ierr);
719   }
720 
721   return CEED_ERROR_SUCCESS;
722 }
723 
724 //------------------------------------------------------------------------------
725 // Destroy the vector
726 //------------------------------------------------------------------------------
727 static int CeedVectorDestroy_Hip(const CeedVector vec) {
728   int ierr;
729   Ceed ceed;
730   ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
731   CeedVector_Hip *impl;
732   ierr = CeedVectorGetData(vec, &impl); CeedChkBackend(ierr);
733 
734   ierr = hipFree(impl->d_array_owned); CeedChk_Hip(ceed, ierr);
735   ierr = CeedFree(&impl->h_array_owned); CeedChkBackend(ierr);
736   ierr = CeedFree(&impl); CeedChkBackend(ierr);
737 
738   return CEED_ERROR_SUCCESS;
739 }
740 
741 //------------------------------------------------------------------------------
742 // Create a vector of the specified length (does not allocate memory)
743 //------------------------------------------------------------------------------
744 int CeedVectorCreate_Hip(CeedSize n, CeedVector vec) {
745   CeedVector_Hip *impl;
746   int ierr;
747   Ceed ceed;
748   ierr = CeedVectorGetCeed(vec, &ceed); CeedChkBackend(ierr);
749 
750   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "HasValidArray",
751                                 CeedVectorHasValidArray_Hip); CeedChkBackend(ierr);
752   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "HasBorrowedArrayOfType",
753                                 CeedVectorHasBorrowedArrayOfType_Hip);
754   CeedChkBackend(ierr);
755   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SetArray",
756                                 CeedVectorSetArray_Hip); CeedChkBackend(ierr);
757   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "TakeArray",
758                                 CeedVectorTakeArray_Hip); CeedChkBackend(ierr);
759   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SetValue",
760                                 (int (*)())(CeedVectorSetValue_Hip)); CeedChkBackend(ierr);
761   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "SyncArray",
762                                 CeedVectorSyncArray_Hip); CeedChkBackend(ierr);
763   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArray",
764                                 CeedVectorGetArray_Hip); CeedChkBackend(ierr);
765   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayRead",
766                                 CeedVectorGetArrayRead_Hip); CeedChkBackend(ierr);
767   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "GetArrayWrite",
768                                 CeedVectorGetArrayWrite_Hip); CeedChkBackend(ierr);
769   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "Norm",
770                                 CeedVectorNorm_Hip); CeedChkBackend(ierr);
771   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "Reciprocal",
772                                 CeedVectorReciprocal_Hip); CeedChkBackend(ierr);
773   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "Scale",
774                                 (int (*)())(CeedVectorScale_Hip)); CeedChkBackend(ierr);
775   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "AXPY",
776                                 (int (*)())(CeedVectorAXPY_Hip)); CeedChkBackend(ierr);
777   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "PointwiseMult",
778                                 CeedVectorPointwiseMult_Hip); CeedChkBackend(ierr);
779   ierr = CeedSetBackendFunction(ceed, "Vector", vec, "Destroy",
780                                 CeedVectorDestroy_Hip); CeedChkBackend(ierr);
781 
782   ierr = CeedCalloc(1, &impl); CeedChkBackend(ierr);
783   ierr = CeedVectorSetData(vec, impl); CeedChkBackend(ierr);
784 
785   return CEED_ERROR_SUCCESS;
786 }
787