// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed

#include "ceed-magma-gemm-nontensor.h"

#include "ceed-magma-gemm-selector.h"

#ifdef CEED_MAGMA_USE_HIP
#define devblasDgemmStridedBatched hipblasDgemmStridedBatched
#define devblasSgemmStridedBatched hipblasSgemmStridedBatched
#define magma_queue_get_devblas_handle magma_queue_get_hipblas_handle
#define devblas_trans_const hipblas_trans_const
#else
#define devblasDgemmStridedBatched cublasDgemmStridedBatched
#define devblasSgemmStridedBatched cublasSgemmStridedBatched
#define magma_queue_get_devblas_handle magma_queue_get_cublas_handle
#define devblas_trans_const cublas_trans_const
#endif

////////////////////////////////////////////////////////////////////////////////
static inline int magmablas_gemm(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k, CeedScalar alpha,
                                 const CeedScalar *d_A, magma_int_t ldda, const CeedScalar *d_B, magma_int_t lddb, CeedScalar beta, CeedScalar *d_C,
                                 magma_int_t lddc, magma_queue_t queue) {
  if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
    magmablas_sgemm(trans_A, trans_B, m, n, k, (float)alpha, (const float *)d_A, ldda, (const float *)d_B, lddb, (float)beta, (float *)d_C, lddc,
                    queue);
  } else {
    magmablas_dgemm(trans_A, trans_B, m, n, k, (double)alpha, (const double *)d_A, ldda, (const double *)d_B, lddb, (double)beta, (double *)d_C, lddc,
                    queue);
  }
  return 0;
}

////////////////////////////////////////////////////////////////////////////////
static inline int magmablas_gemm_batched_strided(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k,
                                                 CeedScalar alpha, const CeedScalar *d_A, magma_int_t ldda, magma_int_t strideA,
                                                 const CeedScalar *d_B, magma_int_t lddb, magma_int_t strideB, CeedScalar beta, CeedScalar *d_C,
                                                 magma_int_t lddc, magma_int_t strideC, magma_int_t batchCount, magma_queue_t queue) {
  if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
    magmablas_sgemm_batched_strided(trans_A, trans_B, m, n, k, (float)alpha, (const float *)d_A, ldda, strideA, (const float *)d_B, lddb, strideB,
                                    (float)beta, (float *)d_C, lddc, strideC, batchCount, queue);
  } else {
    magmablas_dgemm_batched_strided(trans_A, trans_B, m, n, k, (double)alpha, (const double *)d_A, ldda, strideA, (const double *)d_B, lddb, strideB,
                                    (double)beta, (double *)d_C, lddc, strideC, batchCount, queue);
  }
  return 0;
}

////////////////////////////////////////////////////////////////////////////////
static inline int devblas_gemm(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k, CeedScalar alpha,
                               const CeedScalar *d_A, magma_int_t ldda, const CeedScalar *d_B, magma_int_t lddb, CeedScalar beta, CeedScalar *d_C,
                               magma_int_t lddc, magma_queue_t queue) {
  if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
    magma_sgemm(trans_A, trans_B, m, n, k, (float)alpha, (const float *)d_A, ldda, (const float *)d_B, lddb, (float)beta, (float *)d_C, lddc, queue);
  } else {
    magma_dgemm(trans_A, trans_B, m, n, k, (double)alpha, (const double *)d_A, ldda, (const double *)d_B, lddb, (double)beta, (double *)d_C, lddc,
                queue);
  }
  return 0;
}
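
// Note (added comment): CEED_SCALAR_TYPE is a compile-time constant, so the
// if/else on the scalar type in each of these wrappers should fold to a single
// precision path at compile time; the float/double casts are then no-ops for
// the active precision.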

////////////////////////////////////////////////////////////////////////////////
static inline int devblas_gemm_batched_strided(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k,
                                               CeedScalar alpha, const CeedScalar *d_A, magma_int_t ldda, magma_int_t strideA, const CeedScalar *d_B,
                                               magma_int_t lddb, magma_int_t strideB, CeedScalar beta, CeedScalar *d_C, magma_int_t lddc,
                                               magma_int_t strideC, magma_int_t batchCount, magma_queue_t queue) {
  if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
    devblasSgemmStridedBatched(magma_queue_get_devblas_handle(queue), devblas_trans_const(trans_A), devblas_trans_const(trans_B), (int)m, (int)n,
                               (int)k, (const float *)&alpha, (const float *)d_A, (int)ldda, strideA, (const float *)d_B, (int)lddb, strideB,
                               (const float *)&beta, (float *)d_C, (int)lddc, strideC, (int)batchCount);
  } else {
    devblasDgemmStridedBatched(magma_queue_get_devblas_handle(queue), devblas_trans_const(trans_A), devblas_trans_const(trans_B), (int)m, (int)n,
                               (int)k, (const double *)&alpha, (const double *)d_A, (int)ldda, strideA, (const double *)d_B, (int)lddb, strideB,
                               (const double *)&beta, (double *)d_C, (int)lddc, strideC, (int)batchCount);
  }
  return 0;
}

////////////////////////////////////////////////////////////////////////////////
int magma_gemm_nontensor(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k, CeedScalar alpha,
                         const CeedScalar *d_A, magma_int_t ldda, const CeedScalar *d_B, magma_int_t lddb, CeedScalar beta, CeedScalar *d_C,
                         magma_int_t lddc, magma_queue_t queue) {
  magma_int_t nbatch, use_magmablas;
  magma_int_t arch = magma_getdevice_arch();

  // check for specific transpositions (NN and TN only)
  bool NN = trans_A == MagmaNoTrans && trans_B == MagmaNoTrans;
  bool TN = trans_A == MagmaTrans && trans_B == MagmaNoTrans;
  if (!(NN || TN)) {
    // default case -- no specific tuning
    devblas_gemm(trans_A, trans_B, m, n, k, alpha, d_A, ldda, d_B, lddb, beta, d_C, lddc, queue);
    return 0;
  }

  // get tuning decision
  char trans     = (trans_A == MagmaNoTrans) ? 'n' : 't';
  char precision = (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) ? 's' : 'd';
  gemm_selector(arch, precision, trans, m, n, k, &nbatch, &use_magmablas);

  // perform the gemm operation
  if (nbatch == n) {
    // no batching
    if (use_magmablas) {
      magmablas_gemm(trans_A, trans_B, m, n, k, alpha, d_A, ldda, d_B, lddb, beta, d_C, lddc, queue);
    } else {
      devblas_gemm(trans_A, trans_B, m, n, k, alpha, d_A, ldda, d_B, lddb, beta, d_C, lddc, queue);
    }
  } else {
    // use batch kernels: split the n columns of B and C into batchCount panels
    // of nbatch columns each, all sharing the same A (strideA = 0)
    magma_int_t batchCount = n / nbatch;
    magma_int_t n2         = n - (batchCount * nbatch);
    magma_int_t strideA    = 0;
    magma_int_t strideB    = lddb * nbatch;
    magma_int_t strideC    = lddc * nbatch;

    if (use_magmablas) {
      if (batchCount > 0) {
        magmablas_gemm_batched_strided(trans_A, trans_B, m, nbatch, k, alpha, d_A, ldda, strideA, d_B, lddb, strideB, beta, d_C, lddc, strideC,
                                       batchCount, queue);
      }

      // cleanup: remaining n2 columns that do not fill a whole panel
      if (n2 > 0) {
        devblas_gemm(trans_A, trans_B, m, n2, k, alpha, d_A, ldda, d_B + batchCount * strideB, lddb, beta, d_C + batchCount * strideC, lddc, queue);
      }
    } else {
      if (batchCount > 0) {
        devblas_gemm_batched_strided(trans_A, trans_B, m, nbatch, k, alpha, d_A, ldda, strideA, d_B, lddb, strideB, beta, d_C, lddc, strideC,
                                     batchCount, queue);
      }

      // cleanup: remaining n2 columns, issued as a single-batch call
      if (n2 > 0) {
        devblas_gemm_batched_strided(trans_A, trans_B, m, n2, k, alpha, d_A, ldda, strideA, d_B + batchCount * strideB, lddb, strideB, beta,
                                     d_C + batchCount * strideC, lddc, strideC, 1, queue);
      }
    }
  }

  // wait for the operation to complete
  ceed_magma_queue_sync(queue);

  return 0;
}
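
////////////////////////////////////////////////////////////////////////////////
// A minimal usage sketch (added comment, illustrative only, not part of this
// backend). It assumes a MAGMA queue on device 0 and device buffers d_A
// (k x m), d_B (k x n), and d_C (m x n) already allocated and filled; the
// buffer names and shapes are hypothetical. The routine synchronizes the
// queue before returning, so d_C is ready on return.
//
//   magma_queue_t queue;
//   magma_queue_create(0, &queue);
//   // C = alpha * A^T * B + beta * C (the TN case covered by the selector)
//   magma_gemm_nontensor(MagmaTrans, MagmaNoTrans, m, n, k, 1.0, d_A, k, d_B, k, 0.0, d_C, m, queue);
//   magma_queue_destroy(queue);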