1*9ba83ac0SJeremy L Thompson // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors. 2940a72f1SSebastian Grimberg // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3940a72f1SSebastian Grimberg // 4940a72f1SSebastian Grimberg // SPDX-License-Identifier: BSD-2-Clause 5940a72f1SSebastian Grimberg // 6940a72f1SSebastian Grimberg // This file is part of CEED: http://github.com/ceed 7940a72f1SSebastian Grimberg 8940a72f1SSebastian Grimberg #include "ceed-magma-gemm-nontensor.h" 9940a72f1SSebastian Grimberg #include "ceed-magma-gemm-selector.h" 10940a72f1SSebastian Grimberg 11940a72f1SSebastian Grimberg #ifdef CEED_MAGMA_USE_HIP 12940a72f1SSebastian Grimberg #define devblasDgemmStridedBatched hipblasDgemmStridedBatched 13940a72f1SSebastian Grimberg #define devblasSgemmStridedBatched hipblasSgemmStridedBatched 14940a72f1SSebastian Grimberg #define magma_queue_get_devblas_handle magma_queue_get_hipblas_handle 15940a72f1SSebastian Grimberg #define devblas_trans_const hipblas_trans_const 16940a72f1SSebastian Grimberg #else 17940a72f1SSebastian Grimberg #define devblasDgemmStridedBatched cublasDgemmStridedBatched 18940a72f1SSebastian Grimberg #define devblasSgemmStridedBatched cublasSgemmStridedBatched 19940a72f1SSebastian Grimberg #define magma_queue_get_devblas_handle magma_queue_get_cublas_handle 20940a72f1SSebastian Grimberg #define devblas_trans_const cublas_trans_const 21940a72f1SSebastian Grimberg #endif 22940a72f1SSebastian Grimberg 23940a72f1SSebastian Grimberg //////////////////////////////////////////////////////////////////////////////// 24940a72f1SSebastian Grimberg static inline int magmablas_gemm(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k, CeedScalar alpha, 25940a72f1SSebastian Grimberg const CeedScalar *d_A, magma_int_t ldda, const CeedScalar *d_B, magma_int_t lddb, CeedScalar beta, CeedScalar *d_C, 26940a72f1SSebastian Grimberg magma_int_t lddc, magma_queue_t queue) { 27940a72f1SSebastian Grimberg if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) { 28940a72f1SSebastian Grimberg magmablas_sgemm(trans_A, trans_B, m, n, k, (float)alpha, (const float *)d_A, ldda, (const float *)d_B, lddb, (float)beta, (float *)d_C, lddc, 29940a72f1SSebastian Grimberg queue); 30940a72f1SSebastian Grimberg } else { 31940a72f1SSebastian Grimberg magmablas_dgemm(trans_A, trans_B, m, n, k, (double)alpha, (const double *)d_A, ldda, (const double *)d_B, lddb, (double)beta, (double *)d_C, lddc, 32940a72f1SSebastian Grimberg queue); 33940a72f1SSebastian Grimberg } 34940a72f1SSebastian Grimberg return 0; 35940a72f1SSebastian Grimberg } 36940a72f1SSebastian Grimberg 37940a72f1SSebastian Grimberg //////////////////////////////////////////////////////////////////////////////// 38940a72f1SSebastian Grimberg static inline int magmablas_gemm_batched_strided(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k, 39940a72f1SSebastian Grimberg CeedScalar alpha, const CeedScalar *d_A, magma_int_t ldda, magma_int_t strideA, 40940a72f1SSebastian Grimberg const CeedScalar *d_B, magma_int_t lddb, magma_int_t strideB, CeedScalar beta, CeedScalar *d_C, 41940a72f1SSebastian Grimberg magma_int_t lddc, magma_int_t strideC, magma_int_t batchCount, magma_queue_t queue) { 42940a72f1SSebastian Grimberg if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) { 43940a72f1SSebastian Grimberg magmablas_sgemm_batched_strided(trans_A, trans_B, m, n, k, (float)alpha, (const float *)d_A, ldda, strideA, (const float *)d_B, lddb, strideB, 44940a72f1SSebastian Grimberg (float)beta, (float *)d_C, lddc, strideC, batchCount, queue); 45940a72f1SSebastian Grimberg } else { 46940a72f1SSebastian Grimberg magmablas_dgemm_batched_strided(trans_A, trans_B, m, n, k, (double)alpha, (const double *)d_A, ldda, strideA, (const double *)d_B, lddb, strideB, 47940a72f1SSebastian Grimberg (double)beta, (double *)d_C, lddc, strideC, batchCount, queue); 48940a72f1SSebastian Grimberg } 49940a72f1SSebastian Grimberg return 0; 50940a72f1SSebastian Grimberg } 51940a72f1SSebastian Grimberg 52940a72f1SSebastian Grimberg //////////////////////////////////////////////////////////////////////////////// 53940a72f1SSebastian Grimberg static inline int devblas_gemm(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k, CeedScalar alpha, 54940a72f1SSebastian Grimberg const CeedScalar *d_A, magma_int_t ldda, const CeedScalar *d_B, magma_int_t lddb, CeedScalar beta, CeedScalar *d_C, 55940a72f1SSebastian Grimberg magma_int_t lddc, magma_queue_t queue) { 56940a72f1SSebastian Grimberg if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) { 57940a72f1SSebastian Grimberg magma_sgemm(trans_A, trans_B, m, n, k, (float)alpha, (const float *)d_A, ldda, (const float *)d_B, lddb, (float)beta, (float *)d_C, lddc, queue); 58940a72f1SSebastian Grimberg } else { 59940a72f1SSebastian Grimberg magma_dgemm(trans_A, trans_B, m, n, k, (double)alpha, (const double *)d_A, ldda, (const double *)d_B, lddb, (double)beta, (double *)d_C, lddc, 60940a72f1SSebastian Grimberg queue); 61940a72f1SSebastian Grimberg } 62940a72f1SSebastian Grimberg return 0; 63940a72f1SSebastian Grimberg } 64940a72f1SSebastian Grimberg 65940a72f1SSebastian Grimberg //////////////////////////////////////////////////////////////////////////////// 66940a72f1SSebastian Grimberg static inline int devblas_gemm_batched_strided(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k, 67940a72f1SSebastian Grimberg CeedScalar alpha, const CeedScalar *d_A, magma_int_t ldda, magma_int_t strideA, const CeedScalar *d_B, 68940a72f1SSebastian Grimberg magma_int_t lddb, magma_int_t strideB, CeedScalar beta, CeedScalar *d_C, magma_int_t lddc, 69940a72f1SSebastian Grimberg magma_int_t strideC, magma_int_t batchCount, magma_queue_t queue) { 70940a72f1SSebastian Grimberg if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) { 71940a72f1SSebastian Grimberg devblasSgemmStridedBatched(magma_queue_get_devblas_handle(queue), devblas_trans_const(trans_A), devblas_trans_const(trans_B), (int)m, (int)n, 72940a72f1SSebastian Grimberg (int)k, (const float *)&alpha, (const float *)d_A, (int)ldda, strideA, (const float *)d_B, (int)lddb, strideB, 73940a72f1SSebastian Grimberg (const float *)&beta, (float *)d_C, (int)lddc, strideC, (int)batchCount); 74940a72f1SSebastian Grimberg } else { 75940a72f1SSebastian Grimberg devblasDgemmStridedBatched(magma_queue_get_devblas_handle(queue), devblas_trans_const(trans_A), devblas_trans_const(trans_B), (int)m, (int)n, 76940a72f1SSebastian Grimberg (int)k, (const double *)&alpha, (const double *)d_A, (int)ldda, strideA, (const double *)d_B, (int)lddb, strideB, 77940a72f1SSebastian Grimberg (const double *)&beta, (double *)d_C, (int)lddc, strideC, (int)batchCount); 78940a72f1SSebastian Grimberg } 79940a72f1SSebastian Grimberg return 0; 80940a72f1SSebastian Grimberg } 81940a72f1SSebastian Grimberg 82940a72f1SSebastian Grimberg //////////////////////////////////////////////////////////////////////////////// 83940a72f1SSebastian Grimberg int magma_gemm_nontensor(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k, CeedScalar alpha, 84940a72f1SSebastian Grimberg const CeedScalar *d_A, magma_int_t ldda, const CeedScalar *d_B, magma_int_t lddb, CeedScalar beta, CeedScalar *d_C, 85940a72f1SSebastian Grimberg magma_int_t lddc, magma_queue_t queue) { 86940a72f1SSebastian Grimberg magma_int_t nbatch, use_magmablas; 87940a72f1SSebastian Grimberg magma_int_t arch = magma_getdevice_arch(); 88940a72f1SSebastian Grimberg 89940a72f1SSebastian Grimberg // check for specific transpositions (NN and TN only) 90940a72f1SSebastian Grimberg bool NN = trans_A == MagmaNoTrans && trans_B == MagmaNoTrans; 91940a72f1SSebastian Grimberg bool TN = trans_A == MagmaTrans && trans_B == MagmaNoTrans; 92940a72f1SSebastian Grimberg if (!(NN || TN)) { 93940a72f1SSebastian Grimberg // default case -- no specific tuning 94940a72f1SSebastian Grimberg devblas_gemm(trans_A, trans_B, m, n, k, alpha, d_A, ldda, d_B, lddb, beta, d_C, lddc, queue); 95940a72f1SSebastian Grimberg return 0; 96940a72f1SSebastian Grimberg } 97940a72f1SSebastian Grimberg 98940a72f1SSebastian Grimberg // get tuning decision 99940a72f1SSebastian Grimberg char trans = (trans_A == MagmaNoTrans) ? 'n' : 't'; 100940a72f1SSebastian Grimberg char precision = (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) ? 's' : 'd'; 101940a72f1SSebastian Grimberg gemm_selector(arch, precision, trans, m, n, k, &nbatch, &use_magmablas); 102940a72f1SSebastian Grimberg 103940a72f1SSebastian Grimberg // perform the gemm operation 104940a72f1SSebastian Grimberg if (nbatch == n) { 105940a72f1SSebastian Grimberg // no batching 106940a72f1SSebastian Grimberg if (use_magmablas) { 107940a72f1SSebastian Grimberg magmablas_gemm(trans_A, trans_B, m, n, k, alpha, d_A, ldda, d_B, lddb, beta, d_C, lddc, queue); 108940a72f1SSebastian Grimberg } else { 109940a72f1SSebastian Grimberg devblas_gemm(trans_A, trans_B, m, n, k, alpha, d_A, ldda, d_B, lddb, beta, d_C, lddc, queue); 110940a72f1SSebastian Grimberg } 111940a72f1SSebastian Grimberg } else { 112940a72f1SSebastian Grimberg // use batch kernels 113940a72f1SSebastian Grimberg magma_int_t batchCount = n / nbatch; 114940a72f1SSebastian Grimberg magma_int_t n2 = n - (batchCount * nbatch); 115940a72f1SSebastian Grimberg magma_int_t strideA = 0; 116940a72f1SSebastian Grimberg magma_int_t strideB = lddb * nbatch; 117940a72f1SSebastian Grimberg magma_int_t strideC = lddc * nbatch; 118940a72f1SSebastian Grimberg 119940a72f1SSebastian Grimberg if (use_magmablas) { 120940a72f1SSebastian Grimberg if (batchCount > 0) { 121940a72f1SSebastian Grimberg magmablas_gemm_batched_strided(trans_A, trans_B, m, nbatch, k, alpha, d_A, ldda, strideA, d_B, lddb, strideB, beta, d_C, lddc, strideC, 122940a72f1SSebastian Grimberg batchCount, queue); 123940a72f1SSebastian Grimberg } 124940a72f1SSebastian Grimberg 125940a72f1SSebastian Grimberg // cleanup 126940a72f1SSebastian Grimberg if (n2 > 0) { 127940a72f1SSebastian Grimberg devblas_gemm(trans_A, trans_B, m, n2, k, alpha, d_A, ldda, d_B + batchCount * strideB, lddb, beta, d_C + batchCount * strideC, lddc, queue); 128940a72f1SSebastian Grimberg } 129940a72f1SSebastian Grimberg } else { 130940a72f1SSebastian Grimberg if (batchCount > 0) { 131940a72f1SSebastian Grimberg devblas_gemm_batched_strided(trans_A, trans_B, m, nbatch, k, alpha, d_A, ldda, strideA, d_B, lddb, strideB, beta, d_C, lddc, strideC, 132940a72f1SSebastian Grimberg batchCount, queue); 133940a72f1SSebastian Grimberg } 134940a72f1SSebastian Grimberg 135940a72f1SSebastian Grimberg // cleanup 136940a72f1SSebastian Grimberg if (n2 > 0) { 137940a72f1SSebastian Grimberg devblas_gemm_batched_strided(trans_A, trans_B, m, n2, k, alpha, d_A, ldda, strideA, d_B + batchCount * strideB, lddb, strideB, beta, 138940a72f1SSebastian Grimberg d_C + batchCount * strideC, lddc, strideC, 1, queue); 139940a72f1SSebastian Grimberg } 140940a72f1SSebastian Grimberg } 141940a72f1SSebastian Grimberg } 142940a72f1SSebastian Grimberg 143940a72f1SSebastian Grimberg // wait for the operation to complete 144940a72f1SSebastian Grimberg ceed_magma_queue_sync(queue); 145940a72f1SSebastian Grimberg 146940a72f1SSebastian Grimberg return 0; 147940a72f1SSebastian Grimberg } 148