// Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed

#include "ceed-magma-gemm-nontensor.h"
#include "ceed-magma-gemm-selector.h"

#ifdef CEED_MAGMA_USE_HIP
#define devblasDgemmStridedBatched hipblasDgemmStridedBatched
#define devblasSgemmStridedBatched hipblasSgemmStridedBatched
#define magma_queue_get_devblas_handle magma_queue_get_hipblas_handle
#define devblas_trans_const hipblas_trans_const
#else
#define devblasDgemmStridedBatched cublasDgemmStridedBatched
#define devblasSgemmStridedBatched cublasSgemmStridedBatched
#define magma_queue_get_devblas_handle magma_queue_get_cublas_handle
#define devblas_trans_const cublas_trans_const
#endif

////////////////////////////////////////////////////////////////////////////////
magmablas_gemm(magma_trans_t trans_A,magma_trans_t trans_B,magma_int_t m,magma_int_t n,magma_int_t k,CeedScalar alpha,const CeedScalar * d_A,magma_int_t ldda,const CeedScalar * d_B,magma_int_t lddb,CeedScalar beta,CeedScalar * d_C,magma_int_t lddc,magma_queue_t queue)24940a72f1SSebastian Grimberg static inline int magmablas_gemm(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k, CeedScalar alpha,
25940a72f1SSebastian Grimberg const CeedScalar *d_A, magma_int_t ldda, const CeedScalar *d_B, magma_int_t lddb, CeedScalar beta, CeedScalar *d_C,
26940a72f1SSebastian Grimberg magma_int_t lddc, magma_queue_t queue) {
27940a72f1SSebastian Grimberg if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
28940a72f1SSebastian Grimberg magmablas_sgemm(trans_A, trans_B, m, n, k, (float)alpha, (const float *)d_A, ldda, (const float *)d_B, lddb, (float)beta, (float *)d_C, lddc,
29940a72f1SSebastian Grimberg queue);
30940a72f1SSebastian Grimberg } else {
31940a72f1SSebastian Grimberg magmablas_dgemm(trans_A, trans_B, m, n, k, (double)alpha, (const double *)d_A, ldda, (const double *)d_B, lddb, (double)beta, (double *)d_C, lddc,
32940a72f1SSebastian Grimberg queue);
33940a72f1SSebastian Grimberg }
34940a72f1SSebastian Grimberg return 0;
35940a72f1SSebastian Grimberg }

////////////////////////////////////////////////////////////////////////////////
magmablas_gemm_batched_strided(magma_trans_t trans_A,magma_trans_t trans_B,magma_int_t m,magma_int_t n,magma_int_t k,CeedScalar alpha,const CeedScalar * d_A,magma_int_t ldda,magma_int_t strideA,const CeedScalar * d_B,magma_int_t lddb,magma_int_t strideB,CeedScalar beta,CeedScalar * d_C,magma_int_t lddc,magma_int_t strideC,magma_int_t batchCount,magma_queue_t queue)38940a72f1SSebastian Grimberg static inline int magmablas_gemm_batched_strided(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k,
39940a72f1SSebastian Grimberg CeedScalar alpha, const CeedScalar *d_A, magma_int_t ldda, magma_int_t strideA,
40940a72f1SSebastian Grimberg const CeedScalar *d_B, magma_int_t lddb, magma_int_t strideB, CeedScalar beta, CeedScalar *d_C,
41940a72f1SSebastian Grimberg magma_int_t lddc, magma_int_t strideC, magma_int_t batchCount, magma_queue_t queue) {
42940a72f1SSebastian Grimberg if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
43940a72f1SSebastian Grimberg magmablas_sgemm_batched_strided(trans_A, trans_B, m, n, k, (float)alpha, (const float *)d_A, ldda, strideA, (const float *)d_B, lddb, strideB,
44940a72f1SSebastian Grimberg (float)beta, (float *)d_C, lddc, strideC, batchCount, queue);
45940a72f1SSebastian Grimberg } else {
46940a72f1SSebastian Grimberg magmablas_dgemm_batched_strided(trans_A, trans_B, m, n, k, (double)alpha, (const double *)d_A, ldda, strideA, (const double *)d_B, lddb, strideB,
47940a72f1SSebastian Grimberg (double)beta, (double *)d_C, lddc, strideC, batchCount, queue);
48940a72f1SSebastian Grimberg }
49940a72f1SSebastian Grimberg return 0;
50940a72f1SSebastian Grimberg }

////////////////////////////////////////////////////////////////////////////////
devblas_gemm(magma_trans_t trans_A,magma_trans_t trans_B,magma_int_t m,magma_int_t n,magma_int_t k,CeedScalar alpha,const CeedScalar * d_A,magma_int_t ldda,const CeedScalar * d_B,magma_int_t lddb,CeedScalar beta,CeedScalar * d_C,magma_int_t lddc,magma_queue_t queue)53940a72f1SSebastian Grimberg static inline int devblas_gemm(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k, CeedScalar alpha,
54940a72f1SSebastian Grimberg const CeedScalar *d_A, magma_int_t ldda, const CeedScalar *d_B, magma_int_t lddb, CeedScalar beta, CeedScalar *d_C,
55940a72f1SSebastian Grimberg magma_int_t lddc, magma_queue_t queue) {
56940a72f1SSebastian Grimberg if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
57940a72f1SSebastian Grimberg magma_sgemm(trans_A, trans_B, m, n, k, (float)alpha, (const float *)d_A, ldda, (const float *)d_B, lddb, (float)beta, (float *)d_C, lddc, queue);
58940a72f1SSebastian Grimberg } else {
59940a72f1SSebastian Grimberg magma_dgemm(trans_A, trans_B, m, n, k, (double)alpha, (const double *)d_A, ldda, (const double *)d_B, lddb, (double)beta, (double *)d_C, lddc,
60940a72f1SSebastian Grimberg queue);
61940a72f1SSebastian Grimberg }
62940a72f1SSebastian Grimberg return 0;
63940a72f1SSebastian Grimberg }

////////////////////////////////////////////////////////////////////////////////
devblas_gemm_batched_strided(magma_trans_t trans_A,magma_trans_t trans_B,magma_int_t m,magma_int_t n,magma_int_t k,CeedScalar alpha,const CeedScalar * d_A,magma_int_t ldda,magma_int_t strideA,const CeedScalar * d_B,magma_int_t lddb,magma_int_t strideB,CeedScalar beta,CeedScalar * d_C,magma_int_t lddc,magma_int_t strideC,magma_int_t batchCount,magma_queue_t queue)66940a72f1SSebastian Grimberg static inline int devblas_gemm_batched_strided(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k,
67940a72f1SSebastian Grimberg CeedScalar alpha, const CeedScalar *d_A, magma_int_t ldda, magma_int_t strideA, const CeedScalar *d_B,
68940a72f1SSebastian Grimberg magma_int_t lddb, magma_int_t strideB, CeedScalar beta, CeedScalar *d_C, magma_int_t lddc,
69940a72f1SSebastian Grimberg magma_int_t strideC, magma_int_t batchCount, magma_queue_t queue) {
70940a72f1SSebastian Grimberg if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
71940a72f1SSebastian Grimberg devblasSgemmStridedBatched(magma_queue_get_devblas_handle(queue), devblas_trans_const(trans_A), devblas_trans_const(trans_B), (int)m, (int)n,
72940a72f1SSebastian Grimberg (int)k, (const float *)&alpha, (const float *)d_A, (int)ldda, strideA, (const float *)d_B, (int)lddb, strideB,
73940a72f1SSebastian Grimberg (const float *)&beta, (float *)d_C, (int)lddc, strideC, (int)batchCount);
74940a72f1SSebastian Grimberg } else {
75940a72f1SSebastian Grimberg devblasDgemmStridedBatched(magma_queue_get_devblas_handle(queue), devblas_trans_const(trans_A), devblas_trans_const(trans_B), (int)m, (int)n,
76940a72f1SSebastian Grimberg (int)k, (const double *)&alpha, (const double *)d_A, (int)ldda, strideA, (const double *)d_B, (int)lddb, strideB,
77940a72f1SSebastian Grimberg (const double *)&beta, (double *)d_C, (int)lddc, strideC, (int)batchCount);
78940a72f1SSebastian Grimberg }
79940a72f1SSebastian Grimberg return 0;
80940a72f1SSebastian Grimberg }

////////////////////////////////////////////////////////////////////////////////
magma_gemm_nontensor(magma_trans_t trans_A,magma_trans_t trans_B,magma_int_t m,magma_int_t n,magma_int_t k,CeedScalar alpha,const CeedScalar * d_A,magma_int_t ldda,const CeedScalar * d_B,magma_int_t lddb,CeedScalar beta,CeedScalar * d_C,magma_int_t lddc,magma_queue_t queue)83940a72f1SSebastian Grimberg int magma_gemm_nontensor(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k, CeedScalar alpha,
84940a72f1SSebastian Grimberg const CeedScalar *d_A, magma_int_t ldda, const CeedScalar *d_B, magma_int_t lddb, CeedScalar beta, CeedScalar *d_C,
85940a72f1SSebastian Grimberg magma_int_t lddc, magma_queue_t queue) {
86940a72f1SSebastian Grimberg magma_int_t nbatch, use_magmablas;
87940a72f1SSebastian Grimberg magma_int_t arch = magma_getdevice_arch();
88940a72f1SSebastian Grimberg
89940a72f1SSebastian Grimberg // check for specific transpositions (NN and TN only)
90940a72f1SSebastian Grimberg bool NN = trans_A == MagmaNoTrans && trans_B == MagmaNoTrans;
91940a72f1SSebastian Grimberg bool TN = trans_A == MagmaTrans && trans_B == MagmaNoTrans;
92940a72f1SSebastian Grimberg if (!(NN || TN)) {
93940a72f1SSebastian Grimberg // default case -- no specific tuning
94940a72f1SSebastian Grimberg devblas_gemm(trans_A, trans_B, m, n, k, alpha, d_A, ldda, d_B, lddb, beta, d_C, lddc, queue);
95940a72f1SSebastian Grimberg return 0;
96940a72f1SSebastian Grimberg }
97940a72f1SSebastian Grimberg
98940a72f1SSebastian Grimberg // get tuning decision
99940a72f1SSebastian Grimberg char trans = (trans_A == MagmaNoTrans) ? 'n' : 't';
100940a72f1SSebastian Grimberg char precision = (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) ? 's' : 'd';
101940a72f1SSebastian Grimberg gemm_selector(arch, precision, trans, m, n, k, &nbatch, &use_magmablas);
102940a72f1SSebastian Grimberg
103940a72f1SSebastian Grimberg // perform the gemm operation
104940a72f1SSebastian Grimberg if (nbatch == n) {
105940a72f1SSebastian Grimberg // no batching
106940a72f1SSebastian Grimberg if (use_magmablas) {
107940a72f1SSebastian Grimberg magmablas_gemm(trans_A, trans_B, m, n, k, alpha, d_A, ldda, d_B, lddb, beta, d_C, lddc, queue);
108940a72f1SSebastian Grimberg } else {
109940a72f1SSebastian Grimberg devblas_gemm(trans_A, trans_B, m, n, k, alpha, d_A, ldda, d_B, lddb, beta, d_C, lddc, queue);
110940a72f1SSebastian Grimberg }
111940a72f1SSebastian Grimberg } else {
112940a72f1SSebastian Grimberg // use batch kernels
113940a72f1SSebastian Grimberg magma_int_t batchCount = n / nbatch;
114940a72f1SSebastian Grimberg magma_int_t n2 = n - (batchCount * nbatch);
115940a72f1SSebastian Grimberg magma_int_t strideA = 0;
116940a72f1SSebastian Grimberg magma_int_t strideB = lddb * nbatch;
117940a72f1SSebastian Grimberg magma_int_t strideC = lddc * nbatch;
118940a72f1SSebastian Grimberg
119940a72f1SSebastian Grimberg if (use_magmablas) {
120940a72f1SSebastian Grimberg if (batchCount > 0) {
121940a72f1SSebastian Grimberg magmablas_gemm_batched_strided(trans_A, trans_B, m, nbatch, k, alpha, d_A, ldda, strideA, d_B, lddb, strideB, beta, d_C, lddc, strideC,
122940a72f1SSebastian Grimberg batchCount, queue);
123940a72f1SSebastian Grimberg }
124940a72f1SSebastian Grimberg
125940a72f1SSebastian Grimberg // cleanup
126940a72f1SSebastian Grimberg if (n2 > 0) {
127940a72f1SSebastian Grimberg devblas_gemm(trans_A, trans_B, m, n2, k, alpha, d_A, ldda, d_B + batchCount * strideB, lddb, beta, d_C + batchCount * strideC, lddc, queue);
128940a72f1SSebastian Grimberg }
129940a72f1SSebastian Grimberg } else {
130940a72f1SSebastian Grimberg if (batchCount > 0) {
131940a72f1SSebastian Grimberg devblas_gemm_batched_strided(trans_A, trans_B, m, nbatch, k, alpha, d_A, ldda, strideA, d_B, lddb, strideB, beta, d_C, lddc, strideC,
132940a72f1SSebastian Grimberg batchCount, queue);
133940a72f1SSebastian Grimberg }
134940a72f1SSebastian Grimberg
135940a72f1SSebastian Grimberg // cleanup
136940a72f1SSebastian Grimberg if (n2 > 0) {
137940a72f1SSebastian Grimberg devblas_gemm_batched_strided(trans_A, trans_B, m, n2, k, alpha, d_A, ldda, strideA, d_B + batchCount * strideB, lddb, strideB, beta,
138940a72f1SSebastian Grimberg d_C + batchCount * strideC, lddc, strideC, 1, queue);
139940a72f1SSebastian Grimberg }
140940a72f1SSebastian Grimberg }
141940a72f1SSebastian Grimberg }
142940a72f1SSebastian Grimberg
143940a72f1SSebastian Grimberg // wait for the operation to complete
144940a72f1SSebastian Grimberg ceed_magma_queue_sync(queue);
145940a72f1SSebastian Grimberg
146940a72f1SSebastian Grimberg return 0;
147940a72f1SSebastian Grimberg }