// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed

#include "ceed-magma-gemm-nontensor.h"
#include "ceed-magma-gemm-selector.h"

#ifdef CEED_MAGMA_USE_HIP
#define devblasDgemmStridedBatched hipblasDgemmStridedBatched
#define devblasSgemmStridedBatched hipblasSgemmStridedBatched
#define magma_queue_get_devblas_handle magma_queue_get_hipblas_handle
#define devblas_trans_const hipblas_trans_const
#else
#define devblasDgemmStridedBatched cublasDgemmStridedBatched
#define devblasSgemmStridedBatched cublasSgemmStridedBatched
#define magma_queue_get_devblas_handle magma_queue_get_cublas_handle
#define devblas_trans_const cublas_trans_const
#endif

////////////////////////////////////////////////////////////////////////////////
// Single GEMM dispatched to MAGMA's native kernels in the configured CeedScalar precision
static inline int magmablas_gemm(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k, CeedScalar alpha,
                                 const CeedScalar *d_A, magma_int_t ldda, const CeedScalar *d_B, magma_int_t lddb, CeedScalar beta, CeedScalar *d_C,
                                 magma_int_t lddc, magma_queue_t queue) {
  if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
    magmablas_sgemm(trans_A, trans_B, m, n, k, (float)alpha, (const float *)d_A, ldda, (const float *)d_B, lddb, (float)beta, (float *)d_C, lddc,
                    queue);
  } else {
    magmablas_dgemm(trans_A, trans_B, m, n, k, (double)alpha, (const double *)d_A, ldda, (const double *)d_B, lddb, (double)beta, (double *)d_C, lddc,
                    queue);
  }
  return 0;
}

////////////////////////////////////////////////////////////////////////////////
// Strided batched GEMM dispatched to MAGMA's native kernels in the configured CeedScalar precision
static inline int magmablas_gemm_batched_strided(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k,
                                                 CeedScalar alpha, const CeedScalar *d_A, magma_int_t ldda, magma_int_t strideA,
                                                 const CeedScalar *d_B, magma_int_t lddb, magma_int_t strideB, CeedScalar beta, CeedScalar *d_C,
                                                 magma_int_t lddc, magma_int_t strideC, magma_int_t batchCount, magma_queue_t queue) {
  if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
    magmablas_sgemm_batched_strided(trans_A, trans_B, m, n, k, (float)alpha, (const float *)d_A, ldda, strideA, (const float *)d_B, lddb, strideB,
                                    (float)beta, (float *)d_C, lddc, strideC, batchCount, queue);
  } else {
    magmablas_dgemm_batched_strided(trans_A, trans_B, m, n, k, (double)alpha, (const double *)d_A, ldda, strideA, (const double *)d_B, lddb, strideB,
                                    (double)beta, (double *)d_C, lddc, strideC, batchCount, queue);
  }
  return 0;
}

////////////////////////////////////////////////////////////////////////////////
// Single GEMM dispatched to the vendor BLAS (cuBLAS or hipBLAS) through magma_sgemm/magma_dgemm
static inline int devblas_gemm(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k, CeedScalar alpha,
                               const CeedScalar *d_A, magma_int_t ldda, const CeedScalar *d_B, magma_int_t lddb, CeedScalar beta, CeedScalar *d_C,
                               magma_int_t lddc, magma_queue_t queue) {
  if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
    magma_sgemm(trans_A, trans_B, m, n, k, (float)alpha, (const float *)d_A, ldda, (const float *)d_B, lddb, (float)beta, (float *)d_C, lddc, queue);
  } else {
    magma_dgemm(trans_A, trans_B, m, n, k, (double)alpha, (const double *)d_A, ldda, (const double *)d_B, lddb, (double)beta, (double *)d_C, lddc,
                queue);
  }
  return 0;
}
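
/*
  Note on conventions (standard strided-batched semantics of cuBLAS/hipBLAS and MAGMA, stated
  here for reference): on column-major data, batch i (0 <= i < batchCount) of a strided batched
  GEMM computes

    C + i*strideC = alpha * op(A + i*strideA) * op(B + i*strideB) + beta * (C + i*strideC)

  In particular, strideA == 0 makes every batch reuse the same matrix A, which is how
  magma_gemm_nontensor() below shares one matrix across all column blocks of B and C.
*/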

////////////////////////////////////////////////////////////////////////////////
// Strided batched GEMM dispatched directly to the vendor BLAS (cuBLAS or hipBLAS)
static inline int devblas_gemm_batched_strided(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k,
                                               CeedScalar alpha, const CeedScalar *d_A, magma_int_t ldda, magma_int_t strideA, const CeedScalar *d_B,
                                               magma_int_t lddb, magma_int_t strideB, CeedScalar beta, CeedScalar *d_C, magma_int_t lddc,
                                               magma_int_t strideC, magma_int_t batchCount, magma_queue_t queue) {
  if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
    devblasSgemmStridedBatched(magma_queue_get_devblas_handle(queue), devblas_trans_const(trans_A), devblas_trans_const(trans_B), (int)m, (int)n,
                               (int)k, (const float *)&alpha, (const float *)d_A, (int)ldda, strideA, (const float *)d_B, (int)lddb, strideB,
                               (const float *)&beta, (float *)d_C, (int)lddc, strideC, (int)batchCount);
  } else {
    devblasDgemmStridedBatched(magma_queue_get_devblas_handle(queue), devblas_trans_const(trans_A), devblas_trans_const(trans_B), (int)m, (int)n,
                               (int)k, (const double *)&alpha, (const double *)d_A, (int)ldda, strideA, (const double *)d_B, (int)lddb, strideB,
                               (const double *)&beta, (double *)d_C, (int)lddc, strideC, (int)batchCount);
  }
  return 0;
}

////////////////////////////////////////////////////////////////////////////////
// Tuned GEMM for non-tensor basis actions: consults gemm_selector() to choose between MAGMA and
// vendor BLAS kernels and, when profitable, to split the GEMM along the n dimension into strided
// batched kernels
int magma_gemm_nontensor(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k, CeedScalar alpha,
                         const CeedScalar *d_A, magma_int_t ldda, const CeedScalar *d_B, magma_int_t lddb, CeedScalar beta, CeedScalar *d_C,
                         magma_int_t lddc, magma_queue_t queue) {
  magma_int_t nbatch, use_magmablas;
  magma_int_t arch = magma_getdevice_arch();

  // check for specific transpositions (NN and TN only)
  bool NN = trans_A == MagmaNoTrans && trans_B == MagmaNoTrans;
  bool TN = trans_A == MagmaTrans && trans_B == MagmaNoTrans;
  if (!(NN || TN)) {
    // default case -- no specific tuning
    devblas_gemm(trans_A, trans_B, m, n, k, alpha, d_A, ldda, d_B, lddb, beta, d_C, lddc, queue);
    return 0;
  }

  // get tuning decision
  char trans     = (trans_A == MagmaNoTrans) ? 'n' : 't';
  char precision = (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) ? 's' : 'd';
  gemm_selector(arch, precision, trans, m, n, k, &nbatch, &use_magmablas);
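
  /*
    gemm_selector() returns the tuning decision for this GPU architecture, precision, and
    transposition: use_magmablas selects between MAGMA's native kernels and the vendor BLAS,
    and nbatch is the column-block size for splitting the GEMM along n (nbatch == n means no
    splitting). Illustrative example (hypothetical selector output, not taken from the tuning
    data): for n = 1024 and nbatch = 160, the code below runs batchCount = 1024 / 160 = 6
    strided batched GEMMs of 160 columns each, then one cleanup GEMM over the remaining
    n2 = 1024 - 6 * 160 = 64 columns.
  */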

  // perform the gemm operation
  if (nbatch == n) {
    // no batching
    if (use_magmablas) {
      magmablas_gemm(trans_A, trans_B, m, n, k, alpha, d_A, ldda, d_B, lddb, beta, d_C, lddc, queue);
    } else {
      devblas_gemm(trans_A, trans_B, m, n, k, alpha, d_A, ldda, d_B, lddb, beta, d_C, lddc, queue);
    }
  } else {
    // use batch kernels, splitting the n columns of B and C into blocks of nbatch
    magma_int_t batchCount = n / nbatch;
    magma_int_t n2         = n - (batchCount * nbatch);
    magma_int_t strideA    = 0;  // the same A is reused by every batch
    magma_int_t strideB    = lddb * nbatch;
    magma_int_t strideC    = lddc * nbatch;

    if (use_magmablas) {
      if (batchCount > 0) {
        magmablas_gemm_batched_strided(trans_A, trans_B, m, nbatch, k, alpha, d_A, ldda, strideA, d_B, lddb, strideB, beta, d_C, lddc, strideC,
                                       batchCount, queue);
      }

      // cleanup: the remaining n2 columns that do not fill a whole block
      if (n2 > 0) {
        devblas_gemm(trans_A, trans_B, m, n2, k, alpha, d_A, ldda, d_B + batchCount * strideB, lddb, beta, d_C + batchCount * strideC, lddc, queue);
      }
    } else {
      if (batchCount > 0) {
        devblas_gemm_batched_strided(trans_A, trans_B, m, nbatch, k, alpha, d_A, ldda, strideA, d_B, lddb, strideB, beta, d_C, lddc, strideC,
                                     batchCount, queue);
      }

      // cleanup: the remaining n2 columns, issued as a single-batch strided call
      if (n2 > 0) {
        devblas_gemm_batched_strided(trans_A, trans_B, m, n2, k, alpha, d_A, ldda, strideA, d_B + batchCount * strideB, lddb, strideB, beta,
                                     d_C + batchCount * strideC, lddc, strideC, 1, queue);
      }
    }
  }

  // wait for the operation to complete
  ceed_magma_queue_sync(queue);

  return 0;
}
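
/*
  Usage sketch (illustrative only; d_interp, d_U, d_V, and the dimensions are hypothetical
  device arrays, and `queue` is assumed to be an initialized magma_queue_t): applying a
  non-tensor basis matrix of size P x Q, stored column-major as d_interp with leading
  dimension P, to nelem element vectors packed as the columns of d_U (P x nelem), producing
  d_V (Q x nelem), i.e. V = interp^T * U -- the TN case tuned above.

    magma_int_t P = 20, Q = 27, nelem = 10000;
    magma_gemm_nontensor(MagmaTrans, MagmaNoTrans, Q, nelem, P, 1.0, d_interp, P, d_U, P, 0.0, d_V, Q, queue);
*/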