// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED:  http://github.com/ceed

#include "ceed-magma-gemm-nontensor.h"

#include "ceed-magma-gemm-selector.h"

// Map the generic devblas* names used in this file onto the vendor BLAS
// library selected at build time: hipBLAS for HIP/ROCm builds
// (CEED_MAGMA_USE_HIP), cuBLAS otherwise (CUDA).
#ifdef CEED_MAGMA_USE_HIP
#define devblasDgemmStridedBatched hipblasDgemmStridedBatched
#define devblasSgemmStridedBatched hipblasSgemmStridedBatched
#define magma_queue_get_devblas_handle magma_queue_get_hipblas_handle
#define devblas_trans_const hipblas_trans_const
#else
#define devblasDgemmStridedBatched cublasDgemmStridedBatched
#define devblasSgemmStridedBatched cublasSgemmStridedBatched
#define magma_queue_get_devblas_handle magma_queue_get_cublas_handle
#define devblas_trans_const cublas_trans_const
#endif

24*940a72f1SSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
25*940a72f1SSebastian Grimberg static inline int magmablas_gemm(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k, CeedScalar alpha,
26*940a72f1SSebastian Grimberg                                  const CeedScalar *d_A, magma_int_t ldda, const CeedScalar *d_B, magma_int_t lddb, CeedScalar beta, CeedScalar *d_C,
27*940a72f1SSebastian Grimberg                                  magma_int_t lddc, magma_queue_t queue) {
28*940a72f1SSebastian Grimberg   if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
29*940a72f1SSebastian Grimberg     magmablas_sgemm(trans_A, trans_B, m, n, k, (float)alpha, (const float *)d_A, ldda, (const float *)d_B, lddb, (float)beta, (float *)d_C, lddc,
30*940a72f1SSebastian Grimberg                     queue);
31*940a72f1SSebastian Grimberg   } else {
32*940a72f1SSebastian Grimberg     magmablas_dgemm(trans_A, trans_B, m, n, k, (double)alpha, (const double *)d_A, ldda, (const double *)d_B, lddb, (double)beta, (double *)d_C, lddc,
33*940a72f1SSebastian Grimberg                     queue);
34*940a72f1SSebastian Grimberg   }
35*940a72f1SSebastian Grimberg   return 0;
36*940a72f1SSebastian Grimberg }
37*940a72f1SSebastian Grimberg 
38*940a72f1SSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
39*940a72f1SSebastian Grimberg static inline int magmablas_gemm_batched_strided(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k,
40*940a72f1SSebastian Grimberg                                                  CeedScalar alpha, const CeedScalar *d_A, magma_int_t ldda, magma_int_t strideA,
41*940a72f1SSebastian Grimberg                                                  const CeedScalar *d_B, magma_int_t lddb, magma_int_t strideB, CeedScalar beta, CeedScalar *d_C,
42*940a72f1SSebastian Grimberg                                                  magma_int_t lddc, magma_int_t strideC, magma_int_t batchCount, magma_queue_t queue) {
43*940a72f1SSebastian Grimberg   if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
44*940a72f1SSebastian Grimberg     magmablas_sgemm_batched_strided(trans_A, trans_B, m, n, k, (float)alpha, (const float *)d_A, ldda, strideA, (const float *)d_B, lddb, strideB,
45*940a72f1SSebastian Grimberg                                     (float)beta, (float *)d_C, lddc, strideC, batchCount, queue);
46*940a72f1SSebastian Grimberg   } else {
47*940a72f1SSebastian Grimberg     magmablas_dgemm_batched_strided(trans_A, trans_B, m, n, k, (double)alpha, (const double *)d_A, ldda, strideA, (const double *)d_B, lddb, strideB,
48*940a72f1SSebastian Grimberg                                     (double)beta, (double *)d_C, lddc, strideC, batchCount, queue);
49*940a72f1SSebastian Grimberg   }
50*940a72f1SSebastian Grimberg   return 0;
51*940a72f1SSebastian Grimberg }
52*940a72f1SSebastian Grimberg 
53*940a72f1SSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
54*940a72f1SSebastian Grimberg static inline int devblas_gemm(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k, CeedScalar alpha,
55*940a72f1SSebastian Grimberg                                const CeedScalar *d_A, magma_int_t ldda, const CeedScalar *d_B, magma_int_t lddb, CeedScalar beta, CeedScalar *d_C,
56*940a72f1SSebastian Grimberg                                magma_int_t lddc, magma_queue_t queue) {
57*940a72f1SSebastian Grimberg   if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
58*940a72f1SSebastian Grimberg     magma_sgemm(trans_A, trans_B, m, n, k, (float)alpha, (const float *)d_A, ldda, (const float *)d_B, lddb, (float)beta, (float *)d_C, lddc, queue);
59*940a72f1SSebastian Grimberg   } else {
60*940a72f1SSebastian Grimberg     magma_dgemm(trans_A, trans_B, m, n, k, (double)alpha, (const double *)d_A, ldda, (const double *)d_B, lddb, (double)beta, (double *)d_C, lddc,
61*940a72f1SSebastian Grimberg                 queue);
62*940a72f1SSebastian Grimberg   }
63*940a72f1SSebastian Grimberg   return 0;
64*940a72f1SSebastian Grimberg }
65*940a72f1SSebastian Grimberg 
66*940a72f1SSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
67*940a72f1SSebastian Grimberg static inline int devblas_gemm_batched_strided(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k,
68*940a72f1SSebastian Grimberg                                                CeedScalar alpha, const CeedScalar *d_A, magma_int_t ldda, magma_int_t strideA, const CeedScalar *d_B,
69*940a72f1SSebastian Grimberg                                                magma_int_t lddb, magma_int_t strideB, CeedScalar beta, CeedScalar *d_C, magma_int_t lddc,
70*940a72f1SSebastian Grimberg                                                magma_int_t strideC, magma_int_t batchCount, magma_queue_t queue) {
71*940a72f1SSebastian Grimberg   if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
72*940a72f1SSebastian Grimberg     devblasSgemmStridedBatched(magma_queue_get_devblas_handle(queue), devblas_trans_const(trans_A), devblas_trans_const(trans_B), (int)m, (int)n,
73*940a72f1SSebastian Grimberg                                (int)k, (const float *)&alpha, (const float *)d_A, (int)ldda, strideA, (const float *)d_B, (int)lddb, strideB,
74*940a72f1SSebastian Grimberg                                (const float *)&beta, (float *)d_C, (int)lddc, strideC, (int)batchCount);
75*940a72f1SSebastian Grimberg   } else {
76*940a72f1SSebastian Grimberg     devblasDgemmStridedBatched(magma_queue_get_devblas_handle(queue), devblas_trans_const(trans_A), devblas_trans_const(trans_B), (int)m, (int)n,
77*940a72f1SSebastian Grimberg                                (int)k, (const double *)&alpha, (const double *)d_A, (int)ldda, strideA, (const double *)d_B, (int)lddb, strideB,
78*940a72f1SSebastian Grimberg                                (const double *)&beta, (double *)d_C, (int)lddc, strideC, (int)batchCount);
79*940a72f1SSebastian Grimberg   }
80*940a72f1SSebastian Grimberg   return 0;
81*940a72f1SSebastian Grimberg }
82*940a72f1SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// Non-tensor basis GEMM entry point: C = alpha * op(A) * op(B) + beta * C on
// device memory. For the NN and TN transpose combinations, a per-architecture
// tuning table (gemm_selector) picks between MAGMA and vendor BLAS kernels and
// may split the n dimension into strided batches of width nbatch; any other
// transpose combination falls through to a single vendor BLAS GEMM.
// Synchronizes the queue before returning 0.
int magma_gemm_nontensor(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k, CeedScalar alpha,
                         const CeedScalar *d_A, magma_int_t ldda, const CeedScalar *d_B, magma_int_t lddb, CeedScalar beta, CeedScalar *d_C,
                         magma_int_t lddc, magma_queue_t queue) {
  magma_int_t nbatch, use_magmablas;
  magma_int_t arch = magma_getdevice_arch();

  // check for specific transpositions (NN and TN only)
  bool NN = trans_A == MagmaNoTrans && trans_B == MagmaNoTrans;
  bool TN = trans_A == MagmaTrans && trans_B == MagmaNoTrans;
  if (!(NN || TN)) {
    // default case -- no specific tuning
    devblas_gemm(trans_A, trans_B, m, n, k, alpha, d_A, ldda, d_B, lddb, beta, d_C, lddc, queue);
    return 0;
  }

  // get tuning decision
  // gemm_selector fills nbatch (batch width along n; nbatch == n means no
  // batching) and use_magmablas (nonzero -> prefer MAGMA kernels).
  char trans     = (trans_A == MagmaNoTrans) ? 'n' : 't';
  char precision = (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) ? 's' : 'd';
  gemm_selector(arch, precision, trans, m, n, k, &nbatch, &use_magmablas);

  // perform the gemm operation
  if (nbatch == n) {
    // no batching
    if (use_magmablas) {
      magmablas_gemm(trans_A, trans_B, m, n, k, alpha, d_A, ldda, d_B, lddb, beta, d_C, lddc, queue);
    } else {
      devblas_gemm(trans_A, trans_B, m, n, k, alpha, d_A, ldda, d_B, lddb, beta, d_C, lddc, queue);
    }
  } else {
    // use batch kernels
    // A is shared by all batches (strideA = 0); B and C advance by nbatch
    // columns per batch. n2 columns remain when nbatch does not divide n.
    magma_int_t batchCount = n / nbatch;
    magma_int_t n2         = n - (batchCount * nbatch);
    magma_int_t strideA    = 0;
    magma_int_t strideB    = lddb * nbatch;
    magma_int_t strideC    = lddc * nbatch;

    if (use_magmablas) {
      if (batchCount > 0) {
        magmablas_gemm_batched_strided(trans_A, trans_B, m, nbatch, k, alpha, d_A, ldda, strideA, d_B, lddb, strideB, beta, d_C, lddc, strideC,
                                       batchCount, queue);
      }

      // cleanup
      // remainder columns go through a single vendor BLAS GEMM
      if (n2 > 0) {
        devblas_gemm(trans_A, trans_B, m, n2, k, alpha, d_A, ldda, d_B + batchCount * strideB, lddb, beta, d_C + batchCount * strideC, lddc, queue);
      }
    } else {
      if (batchCount > 0) {
        devblas_gemm_batched_strided(trans_A, trans_B, m, nbatch, k, alpha, d_A, ldda, strideA, d_B, lddb, strideB, beta, d_C, lddc, strideC,
                                     batchCount, queue);
      }

      // cleanup
      // remainder handled as a batch of one on the vendor BLAS path
      if (n2 > 0) {
        devblas_gemm_batched_strided(trans_A, trans_B, m, n2, k, alpha, d_A, ldda, strideA, d_B + batchCount * strideB, lddb, strideB, beta,
                                     d_C + batchCount * strideC, lddc, strideC, 1, queue);
      }
    }
  }

  // wait for the operation to complete
  ceed_magma_queue_sync(queue);

  return 0;
}