xref: /libCEED/rust/libceed-sys/c-src/backends/magma/ceed-magma-gemm-nontensor.cpp (revision 940a72f1a85a7fc8459dbc83c7f6f7637fe1955b)
// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED:  http://github.com/ceed

8 #include "ceed-magma-gemm-nontensor.h"
9 
10 #include "ceed-magma-gemm-selector.h"
11 
12 #ifdef CEED_MAGMA_USE_HIP
13 #define devblasDgemmStridedBatched hipblasDgemmStridedBatched
14 #define devblasSgemmStridedBatched hipblasSgemmStridedBatched
15 #define magma_queue_get_devblas_handle magma_queue_get_hipblas_handle
16 #define devblas_trans_const hipblas_trans_const
17 #else
18 #define devblasDgemmStridedBatched cublasDgemmStridedBatched
19 #define devblasSgemmStridedBatched cublasSgemmStridedBatched
20 #define magma_queue_get_devblas_handle magma_queue_get_cublas_handle
21 #define devblas_trans_const cublas_trans_const
22 #endif
23 
24 ////////////////////////////////////////////////////////////////////////////////
25 static inline int magmablas_gemm(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k, CeedScalar alpha,
26                                  const CeedScalar *d_A, magma_int_t ldda, const CeedScalar *d_B, magma_int_t lddb, CeedScalar beta, CeedScalar *d_C,
27                                  magma_int_t lddc, magma_queue_t queue) {
28   if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
29     magmablas_sgemm(trans_A, trans_B, m, n, k, (float)alpha, (const float *)d_A, ldda, (const float *)d_B, lddb, (float)beta, (float *)d_C, lddc,
30                     queue);
31   } else {
32     magmablas_dgemm(trans_A, trans_B, m, n, k, (double)alpha, (const double *)d_A, ldda, (const double *)d_B, lddb, (double)beta, (double *)d_C, lddc,
33                     queue);
34   }
35   return 0;
36 }
37 
38 ////////////////////////////////////////////////////////////////////////////////
39 static inline int magmablas_gemm_batched_strided(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k,
40                                                  CeedScalar alpha, const CeedScalar *d_A, magma_int_t ldda, magma_int_t strideA,
41                                                  const CeedScalar *d_B, magma_int_t lddb, magma_int_t strideB, CeedScalar beta, CeedScalar *d_C,
42                                                  magma_int_t lddc, magma_int_t strideC, magma_int_t batchCount, magma_queue_t queue) {
43   if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
44     magmablas_sgemm_batched_strided(trans_A, trans_B, m, n, k, (float)alpha, (const float *)d_A, ldda, strideA, (const float *)d_B, lddb, strideB,
45                                     (float)beta, (float *)d_C, lddc, strideC, batchCount, queue);
46   } else {
47     magmablas_dgemm_batched_strided(trans_A, trans_B, m, n, k, (double)alpha, (const double *)d_A, ldda, strideA, (const double *)d_B, lddb, strideB,
48                                     (double)beta, (double *)d_C, lddc, strideC, batchCount, queue);
49   }
50   return 0;
51 }
52 
53 ////////////////////////////////////////////////////////////////////////////////
54 static inline int devblas_gemm(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k, CeedScalar alpha,
55                                const CeedScalar *d_A, magma_int_t ldda, const CeedScalar *d_B, magma_int_t lddb, CeedScalar beta, CeedScalar *d_C,
56                                magma_int_t lddc, magma_queue_t queue) {
57   if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
58     magma_sgemm(trans_A, trans_B, m, n, k, (float)alpha, (const float *)d_A, ldda, (const float *)d_B, lddb, (float)beta, (float *)d_C, lddc, queue);
59   } else {
60     magma_dgemm(trans_A, trans_B, m, n, k, (double)alpha, (const double *)d_A, ldda, (const double *)d_B, lddb, (double)beta, (double *)d_C, lddc,
61                 queue);
62   }
63   return 0;
64 }
65 
66 ////////////////////////////////////////////////////////////////////////////////
67 static inline int devblas_gemm_batched_strided(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k,
68                                                CeedScalar alpha, const CeedScalar *d_A, magma_int_t ldda, magma_int_t strideA, const CeedScalar *d_B,
69                                                magma_int_t lddb, magma_int_t strideB, CeedScalar beta, CeedScalar *d_C, magma_int_t lddc,
70                                                magma_int_t strideC, magma_int_t batchCount, magma_queue_t queue) {
71   if (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) {
72     devblasSgemmStridedBatched(magma_queue_get_devblas_handle(queue), devblas_trans_const(trans_A), devblas_trans_const(trans_B), (int)m, (int)n,
73                                (int)k, (const float *)&alpha, (const float *)d_A, (int)ldda, strideA, (const float *)d_B, (int)lddb, strideB,
74                                (const float *)&beta, (float *)d_C, (int)lddc, strideC, (int)batchCount);
75   } else {
76     devblasDgemmStridedBatched(magma_queue_get_devblas_handle(queue), devblas_trans_const(trans_A), devblas_trans_const(trans_B), (int)m, (int)n,
77                                (int)k, (const double *)&alpha, (const double *)d_A, (int)ldda, strideA, (const double *)d_B, (int)lddb, strideB,
78                                (const double *)&beta, (double *)d_C, (int)lddc, strideC, (int)batchCount);
79   }
80   return 0;
81 }
82 
83 ////////////////////////////////////////////////////////////////////////////////
84 int magma_gemm_nontensor(magma_trans_t trans_A, magma_trans_t trans_B, magma_int_t m, magma_int_t n, magma_int_t k, CeedScalar alpha,
85                          const CeedScalar *d_A, magma_int_t ldda, const CeedScalar *d_B, magma_int_t lddb, CeedScalar beta, CeedScalar *d_C,
86                          magma_int_t lddc, magma_queue_t queue) {
87   magma_int_t nbatch, use_magmablas;
88   magma_int_t arch = magma_getdevice_arch();
89 
90   // check for specific transpositions (NN and TN only)
91   bool NN = trans_A == MagmaNoTrans && trans_B == MagmaNoTrans;
92   bool TN = trans_A == MagmaTrans && trans_B == MagmaNoTrans;
93   if (!(NN || TN)) {
94     // default case -- no specific tuning
95     devblas_gemm(trans_A, trans_B, m, n, k, alpha, d_A, ldda, d_B, lddb, beta, d_C, lddc, queue);
96     return 0;
97   }
98 
99   // get tuning decision
100   char trans     = (trans_A == MagmaNoTrans) ? 'n' : 't';
101   char precision = (CEED_SCALAR_TYPE == CEED_SCALAR_FP32) ? 's' : 'd';
102   gemm_selector(arch, precision, trans, m, n, k, &nbatch, &use_magmablas);
103 
104   // perform the gemm operation
105   if (nbatch == n) {
106     // no batching
107     if (use_magmablas) {
108       magmablas_gemm(trans_A, trans_B, m, n, k, alpha, d_A, ldda, d_B, lddb, beta, d_C, lddc, queue);
109     } else {
110       devblas_gemm(trans_A, trans_B, m, n, k, alpha, d_A, ldda, d_B, lddb, beta, d_C, lddc, queue);
111     }
112   } else {
113     // use batch kernels
114     magma_int_t batchCount = n / nbatch;
115     magma_int_t n2         = n - (batchCount * nbatch);
116     magma_int_t strideA    = 0;
117     magma_int_t strideB    = lddb * nbatch;
118     magma_int_t strideC    = lddc * nbatch;
119 
120     if (use_magmablas) {
121       if (batchCount > 0) {
122         magmablas_gemm_batched_strided(trans_A, trans_B, m, nbatch, k, alpha, d_A, ldda, strideA, d_B, lddb, strideB, beta, d_C, lddc, strideC,
123                                        batchCount, queue);
124       }
125 
126       // cleanup
127       if (n2 > 0) {
128         devblas_gemm(trans_A, trans_B, m, n2, k, alpha, d_A, ldda, d_B + batchCount * strideB, lddb, beta, d_C + batchCount * strideC, lddc, queue);
129       }
130     } else {
131       if (batchCount > 0) {
132         devblas_gemm_batched_strided(trans_A, trans_B, m, nbatch, k, alpha, d_A, ldda, strideA, d_B, lddb, strideB, beta, d_C, lddc, strideC,
133                                      batchCount, queue);
134       }
135 
136       // cleanup
137       if (n2 > 0) {
138         devblas_gemm_batched_strided(trans_A, trans_B, m, n2, k, alpha, d_A, ldda, strideA, d_B + batchCount * strideB, lddb, strideB, beta,
139                                      d_C + batchCount * strideC, lddc, strideC, 1, queue);
140       }
141     }
142   }
143 
144   // wait for the operation to complete
145   ceed_magma_queue_sync(queue);
146 
147   return 0;
148 }
149