1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED: http://github.com/ceed
7
8 #include "ceed-magma-gemm-nontensor.h"
9 #include "ceed-magma-gemm-selector.h"
10
// Backend-neutral aliases for the vendor BLAS names used in this file.
// When CEED_MAGMA_USE_HIP is defined the devblas* names expand to their hipBLAS counterparts,
// otherwise they expand to the cuBLAS ones.
#ifdef CEED_MAGMA_USE_HIP
// hipBLAS mapping
#define devblasDgemmStridedBatched hipblasDgemmStridedBatched
#define devblasSgemmStridedBatched hipblasSgemmStridedBatched
#define magma_queue_get_devblas_handle magma_queue_get_hipblas_handle
#define devblas_trans_const hipblas_trans_const
#else
// cuBLAS mapping
#define devblasDgemmStridedBatched cublasDgemmStridedBatched
#define devblasSgemmStridedBatched cublasSgemmStridedBatched
#define magma_queue_get_devblas_handle magma_queue_get_cublas_handle
#define devblas_trans_const cublas_trans_const
#endif
22
23 ////////////////////////////////////////////////////////////////////////////////
magmablas_gemm(magma_trans_t trans_A,magma_trans_t trans_B,magma_int_t m,magma_int_t n,magma_int_t k,CeedScalar alpha,const CeedScalar * d_A,magma_int_t ldda,const CeedScalar * d_B,magma_int_t lddb,CeedScalar beta,CeedScalar * d_C,magma_int_t lddc,magma_queue_t queue)24 static inline int magmablas_gemm(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k, CeedScalar alpha,
25 const CeedScalar *d_A, magma_int_t ldda, const CeedScalar *d_B, magma_int_t lddb, CeedScalar beta, CeedScalar *d_C,
26 magma_int_t lddc, magma_queue_t queue) {
27 if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
28 magmablas_sgemm(trans_A, trans_B, m, n, k, (float)alpha, (const float *)d_A, ldda, (const float *)d_B, lddb, (float)beta, (float *)d_C, lddc,
29 queue);
30 } else {
31 magmablas_dgemm(trans_A, trans_B, m, n, k, (double)alpha, (const double *)d_A, ldda, (const double *)d_B, lddb, (double)beta, (double *)d_C, lddc,
32 queue);
33 }
34 return 0;
35 }
36
37 ////////////////////////////////////////////////////////////////////////////////
magmablas_gemm_batched_strided(magma_trans_t trans_A,magma_trans_t trans_B,magma_int_t m,magma_int_t n,magma_int_t k,CeedScalar alpha,const CeedScalar * d_A,magma_int_t ldda,magma_int_t strideA,const CeedScalar * d_B,magma_int_t lddb,magma_int_t strideB,CeedScalar beta,CeedScalar * d_C,magma_int_t lddc,magma_int_t strideC,magma_int_t batchCount,magma_queue_t queue)38 static inline int magmablas_gemm_batched_strided(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k,
39 CeedScalar alpha, const CeedScalar *d_A, magma_int_t ldda, magma_int_t strideA,
40 const CeedScalar *d_B, magma_int_t lddb, magma_int_t strideB, CeedScalar beta, CeedScalar *d_C,
41 magma_int_t lddc, magma_int_t strideC, magma_int_t batchCount, magma_queue_t queue) {
42 if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
43 magmablas_sgemm_batched_strided(trans_A, trans_B, m, n, k, (float)alpha, (const float *)d_A, ldda, strideA, (const float *)d_B, lddb, strideB,
44 (float)beta, (float *)d_C, lddc, strideC, batchCount, queue);
45 } else {
46 magmablas_dgemm_batched_strided(trans_A, trans_B, m, n, k, (double)alpha, (const double *)d_A, ldda, strideA, (const double *)d_B, lddb, strideB,
47 (double)beta, (double *)d_C, lddc, strideC, batchCount, queue);
48 }
49 return 0;
50 }
51
52 ////////////////////////////////////////////////////////////////////////////////
devblas_gemm(magma_trans_t trans_A,magma_trans_t trans_B,magma_int_t m,magma_int_t n,magma_int_t k,CeedScalar alpha,const CeedScalar * d_A,magma_int_t ldda,const CeedScalar * d_B,magma_int_t lddb,CeedScalar beta,CeedScalar * d_C,magma_int_t lddc,magma_queue_t queue)53 static inline int devblas_gemm(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k, CeedScalar alpha,
54 const CeedScalar *d_A, magma_int_t ldda, const CeedScalar *d_B, magma_int_t lddb, CeedScalar beta, CeedScalar *d_C,
55 magma_int_t lddc, magma_queue_t queue) {
56 if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
57 magma_sgemm(trans_A, trans_B, m, n, k, (float)alpha, (const float *)d_A, ldda, (const float *)d_B, lddb, (float)beta, (float *)d_C, lddc, queue);
58 } else {
59 magma_dgemm(trans_A, trans_B, m, n, k, (double)alpha, (const double *)d_A, ldda, (const double *)d_B, lddb, (double)beta, (double *)d_C, lddc,
60 queue);
61 }
62 return 0;
63 }
64
65 ////////////////////////////////////////////////////////////////////////////////
devblas_gemm_batched_strided(magma_trans_t trans_A,magma_trans_t trans_B,magma_int_t m,magma_int_t n,magma_int_t k,CeedScalar alpha,const CeedScalar * d_A,magma_int_t ldda,magma_int_t strideA,const CeedScalar * d_B,magma_int_t lddb,magma_int_t strideB,CeedScalar beta,CeedScalar * d_C,magma_int_t lddc,magma_int_t strideC,magma_int_t batchCount,magma_queue_t queue)66 static inline int devblas_gemm_batched_strided(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k,
67 CeedScalar alpha, const CeedScalar *d_A, magma_int_t ldda, magma_int_t strideA, const CeedScalar *d_B,
68 magma_int_t lddb, magma_int_t strideB, CeedScalar beta, CeedScalar *d_C, magma_int_t lddc,
69 magma_int_t strideC, magma_int_t batchCount, magma_queue_t queue) {
70 if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
71 devblasSgemmStridedBatched(magma_queue_get_devblas_handle(queue), devblas_trans_const(trans_A), devblas_trans_const(trans_B), (int)m, (int)n,
72 (int)k, (const float *)&alpha, (const float *)d_A, (int)ldda, strideA, (const float *)d_B, (int)lddb, strideB,
73 (const float *)&beta, (float *)d_C, (int)lddc, strideC, (int)batchCount);
74 } else {
75 devblasDgemmStridedBatched(magma_queue_get_devblas_handle(queue), devblas_trans_const(trans_A), devblas_trans_const(trans_B), (int)m, (int)n,
76 (int)k, (const double *)&alpha, (const double *)d_A, (int)ldda, strideA, (const double *)d_B, (int)lddb, strideB,
77 (const double *)&beta, (double *)d_C, (int)lddc, strideC, (int)batchCount);
78 }
79 return 0;
80 }
81
82 ////////////////////////////////////////////////////////////////////////////////
magma_gemm_nontensor(magma_trans_t trans_A,magma_trans_t trans_B,magma_int_t m,magma_int_t n,magma_int_t k,CeedScalar alpha,const CeedScalar * d_A,magma_int_t ldda,const CeedScalar * d_B,magma_int_t lddb,CeedScalar beta,CeedScalar * d_C,magma_int_t lddc,magma_queue_t queue)83 int magma_gemm_nontensor(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k, CeedScalar alpha,
84 const CeedScalar *d_A, magma_int_t ldda, const CeedScalar *d_B, magma_int_t lddb, CeedScalar beta, CeedScalar *d_C,
85 magma_int_t lddc, magma_queue_t queue) {
86 magma_int_t nbatch, use_magmablas;
87 magma_int_t arch = magma_getdevice_arch();
88
89 // check for specific transpositions (NN and TN only)
90 bool NN = trans_A == MagmaNoTrans && trans_B == MagmaNoTrans;
91 bool TN = trans_A == MagmaTrans && trans_B == MagmaNoTrans;
92 if (!(NN || TN)) {
93 // default case -- no specific tuning
94 devblas_gemm(trans_A, trans_B, m, n, k, alpha, d_A, ldda, d_B, lddb, beta, d_C, lddc, queue);
95 return 0;
96 }
97
98 // get tuning decision
99 char trans = (trans_A == MagmaNoTrans) ? 'n' : 't';
100 char precision = (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) ? 's' : 'd';
101 gemm_selector(arch, precision, trans, m, n, k, &nbatch, &use_magmablas);
102
103 // perform the gemm operation
104 if (nbatch == n) {
105 // no batching
106 if (use_magmablas) {
107 magmablas_gemm(trans_A, trans_B, m, n, k, alpha, d_A, ldda, d_B, lddb, beta, d_C, lddc, queue);
108 } else {
109 devblas_gemm(trans_A, trans_B, m, n, k, alpha, d_A, ldda, d_B, lddb, beta, d_C, lddc, queue);
110 }
111 } else {
112 // use batch kernels
113 magma_int_t batchCount = n / nbatch;
114 magma_int_t n2 = n - (batchCount * nbatch);
115 magma_int_t strideA = 0;
116 magma_int_t strideB = lddb * nbatch;
117 magma_int_t strideC = lddc * nbatch;
118
119 if (use_magmablas) {
120 if (batchCount > 0) {
121 magmablas_gemm_batched_strided(trans_A, trans_B, m, nbatch, k, alpha, d_A, ldda, strideA, d_B, lddb, strideB, beta, d_C, lddc, strideC,
122 batchCount, queue);
123 }
124
125 // cleanup
126 if (n2 > 0) {
127 devblas_gemm(trans_A, trans_B, m, n2, k, alpha, d_A, ldda, d_B + batchCount * strideB, lddb, beta, d_C + batchCount * strideC, lddc, queue);
128 }
129 } else {
130 if (batchCount > 0) {
131 devblas_gemm_batched_strided(trans_A, trans_B, m, nbatch, k, alpha, d_A, ldda, strideA, d_B, lddb, strideB, beta, d_C, lddc, strideC,
132 batchCount, queue);
133 }
134
135 // cleanup
136 if (n2 > 0) {
137 devblas_gemm_batched_strided(trans_A, trans_B, m, n2, k, alpha, d_A, ldda, strideA, d_B + batchCount * strideB, lddb, strideB, beta,
138 d_C + batchCount * strideC, lddc, strideC, 1, queue);
139 }
140 }
141 }
142
143 // wait for the operation to complete
144 ceed_magma_queue_sync(queue);
145
146 return 0;
147 }
148