1*940a72f1SSebastian Grimberg // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2*940a72f1SSebastian Grimberg // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3*940a72f1SSebastian Grimberg // 4*940a72f1SSebastian Grimberg // SPDX-License-Identifier: BSD-2-Clause 5*940a72f1SSebastian Grimberg // 6*940a72f1SSebastian Grimberg // This file is part of CEED: http://github.com/ceed 7*940a72f1SSebastian Grimberg 8*940a72f1SSebastian Grimberg #include "ceed-magma-gemm-nontensor.h" 9*940a72f1SSebastian Grimberg 10*940a72f1SSebastian Grimberg #include "ceed-magma-gemm-selector.h" 11*940a72f1SSebastian Grimberg 12*940a72f1SSebastian Grimberg #ifdef CEED_MAGMA_USE_HIP 13*940a72f1SSebastian Grimberg #define devblasDgemmStridedBatched hipblasDgemmStridedBatched 14*940a72f1SSebastian Grimberg #define devblasSgemmStridedBatched hipblasSgemmStridedBatched 15*940a72f1SSebastian Grimberg #define magma_queue_get_devblas_handle magma_queue_get_hipblas_handle 16*940a72f1SSebastian Grimberg #define devblas_trans_const hipblas_trans_const 17*940a72f1SSebastian Grimberg #else 18*940a72f1SSebastian Grimberg #define devblasDgemmStridedBatched cublasDgemmStridedBatched 19*940a72f1SSebastian Grimberg #define devblasSgemmStridedBatched cublasSgemmStridedBatched 20*940a72f1SSebastian Grimberg #define magma_queue_get_devblas_handle magma_queue_get_cublas_handle 21*940a72f1SSebastian Grimberg #define devblas_trans_const cublas_trans_const 22*940a72f1SSebastian Grimberg #endif 23*940a72f1SSebastian Grimberg 24*940a72f1SSebastian Grimberg //////////////////////////////////////////////////////////////////////////////// 25*940a72f1SSebastian Grimberg static inline int magmablas_gemm(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k, CeedScalar alpha, 26*940a72f1SSebastian Grimberg const CeedScalar *d_A, magma_int_t ldda, const CeedScalar *d_B, magma_int_t lddb, CeedScalar beta, CeedScalar *d_C, 27*940a72f1SSebastian Grimberg magma_int_t lddc, magma_queue_t queue) { 28*940a72f1SSebastian Grimberg if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) { 29*940a72f1SSebastian Grimberg magmablas_sgemm(trans_A, trans_B, m, n, k, (float)alpha, (const float *)d_A, ldda, (const float *)d_B, lddb, (float)beta, (float *)d_C, lddc, 30*940a72f1SSebastian Grimberg queue); 31*940a72f1SSebastian Grimberg } else { 32*940a72f1SSebastian Grimberg magmablas_dgemm(trans_A, trans_B, m, n, k, (double)alpha, (const double *)d_A, ldda, (const double *)d_B, lddb, (double)beta, (double *)d_C, lddc, 33*940a72f1SSebastian Grimberg queue); 34*940a72f1SSebastian Grimberg } 35*940a72f1SSebastian Grimberg return 0; 36*940a72f1SSebastian Grimberg } 37*940a72f1SSebastian Grimberg 38*940a72f1SSebastian Grimberg //////////////////////////////////////////////////////////////////////////////// 39*940a72f1SSebastian Grimberg static inline int magmablas_gemm_batched_strided(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k, 40*940a72f1SSebastian Grimberg CeedScalar alpha, const CeedScalar *d_A, magma_int_t ldda, magma_int_t strideA, 41*940a72f1SSebastian Grimberg const CeedScalar *d_B, magma_int_t lddb, magma_int_t strideB, CeedScalar beta, CeedScalar *d_C, 42*940a72f1SSebastian Grimberg magma_int_t lddc, magma_int_t strideC, magma_int_t batchCount, magma_queue_t queue) { 43*940a72f1SSebastian Grimberg if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) { 44*940a72f1SSebastian Grimberg magmablas_sgemm_batched_strided(trans_A, trans_B, m, n, k, (float)alpha, (const float *)d_A, ldda, strideA, (const float *)d_B, lddb, strideB, 45*940a72f1SSebastian Grimberg (float)beta, (float *)d_C, lddc, strideC, batchCount, queue); 46*940a72f1SSebastian Grimberg } else { 47*940a72f1SSebastian Grimberg magmablas_dgemm_batched_strided(trans_A, trans_B, m, n, k, (double)alpha, (const double *)d_A, ldda, strideA, (const double *)d_B, lddb, strideB, 48*940a72f1SSebastian Grimberg (double)beta, (double *)d_C, lddc, strideC, batchCount, queue); 49*940a72f1SSebastian Grimberg } 50*940a72f1SSebastian Grimberg return 0; 51*940a72f1SSebastian Grimberg } 52*940a72f1SSebastian Grimberg 53*940a72f1SSebastian Grimberg //////////////////////////////////////////////////////////////////////////////// 54*940a72f1SSebastian Grimberg static inline int devblas_gemm(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k, CeedScalar alpha, 55*940a72f1SSebastian Grimberg const CeedScalar *d_A, magma_int_t ldda, const CeedScalar *d_B, magma_int_t lddb, CeedScalar beta, CeedScalar *d_C, 56*940a72f1SSebastian Grimberg magma_int_t lddc, magma_queue_t queue) { 57*940a72f1SSebastian Grimberg if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) { 58*940a72f1SSebastian Grimberg magma_sgemm(trans_A, trans_B, m, n, k, (float)alpha, (const float *)d_A, ldda, (const float *)d_B, lddb, (float)beta, (float *)d_C, lddc, queue); 59*940a72f1SSebastian Grimberg } else { 60*940a72f1SSebastian Grimberg magma_dgemm(trans_A, trans_B, m, n, k, (double)alpha, (const double *)d_A, ldda, (const double *)d_B, lddb, (double)beta, (double *)d_C, lddc, 61*940a72f1SSebastian Grimberg queue); 62*940a72f1SSebastian Grimberg } 63*940a72f1SSebastian Grimberg return 0; 64*940a72f1SSebastian Grimberg } 65*940a72f1SSebastian Grimberg 66*940a72f1SSebastian Grimberg //////////////////////////////////////////////////////////////////////////////// 67*940a72f1SSebastian Grimberg static inline int devblas_gemm_batched_strided(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k, 68*940a72f1SSebastian Grimberg CeedScalar alpha, const CeedScalar *d_A, magma_int_t ldda, magma_int_t strideA, const CeedScalar *d_B, 69*940a72f1SSebastian Grimberg magma_int_t lddb, magma_int_t strideB, CeedScalar beta, CeedScalar *d_C, magma_int_t lddc, 70*940a72f1SSebastian Grimberg magma_int_t strideC, magma_int_t batchCount, magma_queue_t queue) { 71*940a72f1SSebastian Grimberg if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) { 72*940a72f1SSebastian Grimberg devblasSgemmStridedBatched(magma_queue_get_devblas_handle(queue), devblas_trans_const(trans_A), devblas_trans_const(trans_B), (int)m, (int)n, 73*940a72f1SSebastian Grimberg (int)k, (const float *)&alpha, (const float *)d_A, (int)ldda, strideA, (const float *)d_B, (int)lddb, strideB, 74*940a72f1SSebastian Grimberg (const float *)&beta, (float *)d_C, (int)lddc, strideC, (int)batchCount); 75*940a72f1SSebastian Grimberg } else { 76*940a72f1SSebastian Grimberg devblasDgemmStridedBatched(magma_queue_get_devblas_handle(queue), devblas_trans_const(trans_A), devblas_trans_const(trans_B), (int)m, (int)n, 77*940a72f1SSebastian Grimberg (int)k, (const double *)&alpha, (const double *)d_A, (int)ldda, strideA, (const double *)d_B, (int)lddb, strideB, 78*940a72f1SSebastian Grimberg (const double *)&beta, (double *)d_C, (int)lddc, strideC, (int)batchCount); 79*940a72f1SSebastian Grimberg } 80*940a72f1SSebastian Grimberg return 0; 81*940a72f1SSebastian Grimberg } 82*940a72f1SSebastian Grimberg 83*940a72f1SSebastian Grimberg //////////////////////////////////////////////////////////////////////////////// 84*940a72f1SSebastian Grimberg int magma_gemm_nontensor(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k, CeedScalar alpha, 85*940a72f1SSebastian Grimberg const CeedScalar *d_A, magma_int_t ldda, const CeedScalar *d_B, magma_int_t lddb, CeedScalar beta, CeedScalar *d_C, 86*940a72f1SSebastian Grimberg magma_int_t lddc, magma_queue_t queue) { 87*940a72f1SSebastian Grimberg magma_int_t nbatch, use_magmablas; 88*940a72f1SSebastian Grimberg magma_int_t arch = magma_getdevice_arch(); 89*940a72f1SSebastian Grimberg 90*940a72f1SSebastian Grimberg // check for specific transpositions (NN and TN only) 91*940a72f1SSebastian Grimberg bool NN = trans_A == MagmaNoTrans && trans_B == MagmaNoTrans; 92*940a72f1SSebastian Grimberg bool TN = trans_A == MagmaTrans && trans_B == MagmaNoTrans; 93*940a72f1SSebastian Grimberg if (!(NN || TN)) { 94*940a72f1SSebastian Grimberg // default case -- no specific tuning 95*940a72f1SSebastian Grimberg devblas_gemm(trans_A, trans_B, m, n, k, alpha, d_A, ldda, d_B, lddb, beta, d_C, lddc, queue); 96*940a72f1SSebastian Grimberg return 0; 97*940a72f1SSebastian Grimberg } 98*940a72f1SSebastian Grimberg 99*940a72f1SSebastian Grimberg // get tuning decision 100*940a72f1SSebastian Grimberg char trans = (trans_A == MagmaNoTrans) ? 'n' : 't'; 101*940a72f1SSebastian Grimberg char precision = (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) ? 's' : 'd'; 102*940a72f1SSebastian Grimberg gemm_selector(arch, precision, trans, m, n, k, &nbatch, &use_magmablas); 103*940a72f1SSebastian Grimberg 104*940a72f1SSebastian Grimberg // perform the gemm operation 105*940a72f1SSebastian Grimberg if (nbatch == n) { 106*940a72f1SSebastian Grimberg // no batching 107*940a72f1SSebastian Grimberg if (use_magmablas) { 108*940a72f1SSebastian Grimberg magmablas_gemm(trans_A, trans_B, m, n, k, alpha, d_A, ldda, d_B, lddb, beta, d_C, lddc, queue); 109*940a72f1SSebastian Grimberg } else { 110*940a72f1SSebastian Grimberg devblas_gemm(trans_A, trans_B, m, n, k, alpha, d_A, ldda, d_B, lddb, beta, d_C, lddc, queue); 111*940a72f1SSebastian Grimberg } 112*940a72f1SSebastian Grimberg } else { 113*940a72f1SSebastian Grimberg // use batch kernels 114*940a72f1SSebastian Grimberg magma_int_t batchCount = n / nbatch; 115*940a72f1SSebastian Grimberg magma_int_t n2 = n - (batchCount * nbatch); 116*940a72f1SSebastian Grimberg magma_int_t strideA = 0; 117*940a72f1SSebastian Grimberg magma_int_t strideB = lddb * nbatch; 118*940a72f1SSebastian Grimberg magma_int_t strideC = lddc * nbatch; 119*940a72f1SSebastian Grimberg 120*940a72f1SSebastian Grimberg if (use_magmablas) { 121*940a72f1SSebastian Grimberg if (batchCount > 0) { 122*940a72f1SSebastian Grimberg magmablas_gemm_batched_strided(trans_A, trans_B, m, nbatch, k, alpha, d_A, ldda, strideA, d_B, lddb, strideB, beta, d_C, lddc, strideC, 123*940a72f1SSebastian Grimberg batchCount, queue); 124*940a72f1SSebastian Grimberg } 125*940a72f1SSebastian Grimberg 126*940a72f1SSebastian Grimberg // cleanup 127*940a72f1SSebastian Grimberg if (n2 > 0) { 128*940a72f1SSebastian Grimberg devblas_gemm(trans_A, trans_B, m, n2, k, alpha, d_A, ldda, d_B + batchCount * strideB, lddb, beta, d_C + batchCount * strideC, lddc, queue); 129*940a72f1SSebastian Grimberg } 130*940a72f1SSebastian Grimberg } else { 131*940a72f1SSebastian Grimberg if (batchCount > 0) { 132*940a72f1SSebastian Grimberg devblas_gemm_batched_strided(trans_A, trans_B, m, nbatch, k, alpha, d_A, ldda, strideA, d_B, lddb, strideB, beta, d_C, lddc, strideC, 133*940a72f1SSebastian Grimberg batchCount, queue); 134*940a72f1SSebastian Grimberg } 135*940a72f1SSebastian Grimberg 136*940a72f1SSebastian Grimberg // cleanup 137*940a72f1SSebastian Grimberg if (n2 > 0) { 138*940a72f1SSebastian Grimberg devblas_gemm_batched_strided(trans_A, trans_B, m, n2, k, alpha, d_A, ldda, strideA, d_B + batchCount * strideB, lddb, strideB, beta, 139*940a72f1SSebastian Grimberg d_C + batchCount * strideC, lddc, strideC, 1, queue); 140*940a72f1SSebastian Grimberg } 141*940a72f1SSebastian Grimberg } 142*940a72f1SSebastian Grimberg } 143*940a72f1SSebastian Grimberg 144*940a72f1SSebastian Grimberg // wait for the operation to complete 145*940a72f1SSebastian Grimberg ceed_magma_queue_sync(queue); 146*940a72f1SSebastian Grimberg 147*940a72f1SSebastian Grimberg return 0; 148*940a72f1SSebastian Grimberg } 149