xref: /libCEED/rust/libceed-sys/c-src/backends/magma/ceed-magma-gemm-selector.cpp (revision 7c7f2ed8989684d57a66d52f0d039c774f93b471)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <array>
9 #include <limits>
10 #include <vector>
11 
12 #include "ceed-magma-gemm-selector.h"
13 
14 #include "tuning/indices.h"
15 #ifdef CEED_MAGMA_USE_HIP
16 #include "tuning/mi100.h"
17 #include "tuning/mi100_rtc.h"
18 #include "tuning/mi250x.h"
19 #include "tuning/mi250x_rtc.h"
20 #else
21 #include "tuning/a100.h"
22 #include "tuning/a100_rtc.h"
23 #include "tuning/v100.h"
24 #include "tuning/v100_rtc.h"
25 #endif
26 
// These definitions are used to force a certain parameter when generating autotuning data offline
28 // #define CEED_AUTOTUNE_GEMM_SELECTOR_N_BATCH
29 // #define CEED_AUTOTUNE_GEMM_SELECTOR_USE_MAGMA
30 // #define CEED_AUTOTUNE_RTC_NB
31 
32 ////////////////////////////////////////////////////////////////////////////////
33 #ifdef CEED_MAGMA_USE_HIP
34 static inline auto gemm_selector_get_data(int gpu_arch, char precision, char trans_A) -> decltype(dgemm_nn_mi250x) {
35   if (gpu_arch >= 910) {
36     // gfx90a or newer
37     return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_mi250x : sgemm_tn_mi250x) : ((trans_A == 'n') ? dgemm_nn_mi250x : dgemm_tn_mi250x);
38   } else {
39     // gfx908 or older
40     return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_mi100 : sgemm_tn_mi100) : ((trans_A == 'n') ? dgemm_nn_mi100 : dgemm_tn_mi100);
41   }
42 }
43 #else
44 static inline auto gemm_selector_get_data(int gpu_arch, char precision, char trans_A) -> decltype(dgemm_nn_a100) {
45   if (gpu_arch >= 800) {
46     // sm80 or newer
47     return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_a100 : sgemm_tn_a100) : ((trans_A == 'n') ? dgemm_nn_a100 : dgemm_tn_a100);
48   } else {
49     // sm70 or older
50     return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_v100 : sgemm_tn_v100) : ((trans_A == 'n') ? dgemm_nn_v100 : dgemm_tn_v100);
51   }
52 }
53 #endif
54 
55 ////////////////////////////////////////////////////////////////////////////////
56 void gemm_selector(int gpu_arch, char precision, char trans_A, int m, int n, int k, int *n_batch, int *use_magma) {
57 #if defined(CEED_AUTOTUNE_GEMM_SELECTOR_N_BATCH) && defined(CEED_AUTOTUNE_GEMM_SELECTOR_USE_MAGMA)
58   *n_batch   = CEED_AUTOTUNE_GEMM_SELECTOR_N_BATCH;
59   *use_magma = CEED_AUTOTUNE_GEMM_SELECTOR_USE_MAGMA;
60 #else
61   const auto &data = gemm_selector_get_data(gpu_arch, precision, trans_A);
62   int         ir   = -1;
63   double      norm = std::numeric_limits<double>::max();
64 
65   for (size_t i = 0; i < data.size(); i++) {
66     const int &im = data[i][M_INDEX];
67     const int &in = data[i][N_INDEX];
68     const int &ik = data[i][K_INDEX];
69 
70     double mdiff = (double)(im - m);
71     double ndiff = (double)(in - n);
72     double kdiff = (double)(ik - k);
73     double nrm   = mdiff * mdiff + ndiff * ndiff + kdiff * kdiff;
74 
75     if (nrm < norm) {
76       norm = nrm;
77       ir   = i;
78     }
79 
80     if (im == m && in == n && ik == k) {
81       // The input (m, n, k) exactly matches a record in `data`, no need to search further
82       break;
83     }
84   }
85 
86   if (ir >= 0) {
87     // If the closest match indicates that n = n_batch, that means calling the regular non-batch GEMM.
88     // So n_batch is set to n instead of the 'n_batch' entry of the matching record.
89     int n_       = data[ir][N_INDEX];
90     int n_batch_ = data[ir][N_BATCH_INDEX];
91     *n_batch     = (n_ == n_batch_) ? n : n_batch_;
92     *use_magma   = data[ir][USE_MAGMA_INDEX];
93   } else {
94     *n_batch   = n;
95     *use_magma = 0;
96   }
97 #endif
98 }
99 
100 //////////////////////////////////////////////////////////////////////////////
101 #ifdef CEED_MAGMA_USE_HIP
102 static inline auto nontensor_rtc_get_data(int gpu_arch, char trans_A) -> decltype(drtc_n_mi250x) {
103   if (gpu_arch >= 910) {
104     // gfx90a or newer
105     return (trans_A == 'n') ? drtc_n_mi250x : drtc_t_mi250x;
106   } else {
107     // gfx908 or older
108     return (trans_A == 'n') ? drtc_n_mi100 : drtc_t_mi100;
109   }
110 }
111 #else
112 static inline auto nontensor_rtc_get_data(int gpu_arch, char trans_A) -> decltype(drtc_n_a100) {
113   if (gpu_arch >= 800) {
114     // sm80 or newer
115     return (trans_A == 'n') ? drtc_n_a100 : drtc_t_a100;
116   } else {
117     // sm70 or older
118     return (trans_A == 'n') ? drtc_n_v100 : drtc_t_v100;
119   }
120 }
121 #endif
122 
123 ////////////////////////////////////////////////////////////////////////////////
124 CeedInt nontensor_rtc_get_nb(int gpu_arch, char trans_A, int q_comp, int P, int Q, int N) {
125 #ifdef CEED_AUTOTUNE_RTC_NB
126   return CEED_AUTOTUNE_RTC_NB;
127 #else
128   const auto &data = nontensor_rtc_get_data(gpu_arch, trans_A);
129   int         ir   = -1;
130   double      norm = std::numeric_limits<double>::max();
131 
132   for (size_t i = 0; i < data.size(); i++) {
133     // Only seach exact matches for q_comp
134     if (q_comp != data[i][Q_COMP_INDEX_RTC]) {
135       continue;
136     }
137 
138     const int &iP = data[i][P_INDEX_RTC];
139     const int &iQ = data[i][Q_INDEX_RTC];
140     const int &iN = data[i][N_INDEX_RTC];
141 
142     double Pdiff = (double)(iP - P);
143     double Qdiff = (double)(iQ - Q);
144     double Ndiff = (double)(iN - N);
145     double nrm   = Pdiff * Pdiff + Qdiff * Qdiff + Ndiff * Ndiff;
146 
147     if (nrm < norm) {
148       norm = nrm;
149       ir   = i;
150     }
151 
152     if (iP == P && iQ == Q && iN == N) {
153       // The input (P, Q, N) exactly matches a record in `data`, no need to search further
154       break;
155     }
156   }
157 
158   return (ir >= 0) ? data[ir][NB_INDEX_RTC] : 1;
159 #endif
160 }
161