1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED: http://github.com/ceed
7 #pragma once
8
9 #include <ceed.h>
10 #include <ceed/backend.h>
11 #include <hip/hip_runtime.h>
12 #if (HIP_VERSION >= 50200000)
13 #include <hipblas/hipblas.h> // IWYU pragma: export
14 #else
15 #include <hipblas.h> // IWYU pragma: export
16 #endif
17
// Turn the macro arguments (commas included) into a single string literal.
// The `#` operator stringizes the arguments as written, without macro-expanding them first.
#define QUOTE(...) # __VA_ARGS__
19
// Check a hipError_t value and, if it is not hipSuccess, return a CEED_ERROR_BACKEND error from the enclosing function.
//
//   ceed - Ceed object passed to CeedError on failure
//   x    - expression of type hipError_t; evaluated exactly once
//
// The error message is the name returned by hipGetErrorName(). CeedError's trailing arguments are a printf-style format
// plus values, so the name is passed as a "%s" argument rather than as the format itself. This keeps the format string a
// literal, so no '%' in a message can ever be interpreted as a conversion.
#define CeedChk_Hip(ceed, x)                                   \
  do {                                                         \
    hipError_t hip_result = (x);                               \
    if (hip_result != hipSuccess) {                            \
      const char *msg = hipGetErrorName(hip_result);           \
      return CeedError((ceed), CEED_ERROR_BACKEND, "%s", msg); \
    }                                                          \
  } while (0)
28
// Check a hipblasStatus_t value and, if it is not HIPBLAS_STATUS_SUCCESS, return a CEED_ERROR_BACKEND error from the
// enclosing function.
//
//   ceed - Ceed object passed to CeedError on failure
//   x    - expression of type hipblasStatus_t; evaluated exactly once
//
// The message is the enumerator name produced by hipblasGetErrorName(). It is passed as a "%s" argument so that
// CeedError's printf-style format is always a string literal.
#define CeedChk_Hipblas(ceed, x)                               \
  do {                                                         \
    hipblasStatus_t hipblas_result = (x);                      \
    if (hipblas_result != HIPBLAS_STATUS_SUCCESS) {            \
      const char *msg = hipblasGetErrorName(hipblas_result);   \
      return CeedError((ceed), CEED_ERROR_BACKEND, "%s", msg); \
    }                                                          \
  } while (0)
37
// Evaluate a HIP runtime call (given as the variadic arguments) exactly once.
// The resulting hipError_t is stored in a local and handed to CeedChk_Hip, which returns a backend error from the
// enclosing function when the status is not hipSuccess.
#define CeedCallHip(ceed, ...)              \
  do {                                      \
    const hipError_t ierr_q_ = __VA_ARGS__; \
    CeedChk_Hip((ceed), ierr_q_);           \
  } while (0)
43
// Evaluate a hipBLAS call (given as the variadic arguments) exactly once.
// The resulting hipblasStatus_t is stored in a local and handed to CeedChk_Hipblas, which returns a backend error from
// the enclosing function when the status is not HIPBLAS_STATUS_SUCCESS.
#define CeedCallHipblas(ceed, ...)               \
  do {                                           \
    const hipblasStatus_t ierr_q_ = __VA_ARGS__; \
    CeedChk_Hipblas((ceed), ierr_q_);            \
  } while (0)
49
// Expand to a `case` label for the constant `name` that returns the constant's spelling as a string literal.
// No trailing semicolon is included; write `CASE(FOO);` at the use site.
#define CASE(name) case name: return #name
53 // LCOV_EXCL_START
hipblasGetErrorName(hipblasStatus_t error)54 CEED_UNUSED static const char *hipblasGetErrorName(hipblasStatus_t error) {
55 switch (error) {
56 CASE(HIPBLAS_STATUS_SUCCESS);
57 CASE(HIPBLAS_STATUS_NOT_INITIALIZED);
58 CASE(HIPBLAS_STATUS_ALLOC_FAILED);
59 CASE(HIPBLAS_STATUS_INVALID_VALUE);
60 CASE(HIPBLAS_STATUS_ARCH_MISMATCH);
61 CASE(HIPBLAS_STATUS_MAPPING_ERROR);
62 CASE(HIPBLAS_STATUS_EXECUTION_FAILED);
63 CASE(HIPBLAS_STATUS_INTERNAL_ERROR);
64 default:
65 return "HIPBLAS_STATUS_UNKNOWN_ERROR";
66 }
67 }
68 // LCOV_EXCL_STOP
69
70 typedef struct {
71 int device_id;
72 hipblasHandle_t hipblas_handle;
73 struct hipDeviceProp_t device_prop;
74 int opt_block_size;
75 int has_unified_addressing;
76 } Ceed_Hip;
77
78 CEED_INTERN int CeedInit_Hip(Ceed ceed, const char *resource);
79
80 CEED_INTERN int CeedDestroy_Hip(Ceed ceed);
81
82 CEED_INTERN int CeedSetDeviceBoolArray_Hip(Ceed ceed, const bool *source_array, CeedCopyMode copy_mode, CeedSize num_values,
83 const bool **target_array_owned, const bool **target_array_borrowed, const bool **target_array);
84 CEED_INTERN int CeedSetDeviceCeedInt8Array_Hip(Ceed ceed, const CeedInt8 *source_array, CeedCopyMode copy_mode, CeedSize num_values,
85 const CeedInt8 **target_array_owned, const CeedInt8 **target_array_borrowed,
86 const CeedInt8 **target_array);
87 CEED_INTERN int CeedSetDeviceCeedIntArray_Hip(Ceed ceed, const CeedInt *source_array, CeedCopyMode copy_mode, CeedSize num_values,
88 const CeedInt **target_array_owned, const CeedInt **target_array_borrowed,
89 const CeedInt **target_array);
90 CEED_INTERN int CeedSetDeviceCeedScalarArray_Hip(Ceed ceed, const CeedScalar *source_array, CeedCopyMode copy_mode, CeedSize num_values,
91 const CeedScalar **target_array_owned, const CeedScalar **target_array_borrowed,
92 const CeedScalar **target_array);
93