1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED: http://github.com/ceed
7
8 #include <array>
9 #include <limits>
10 #include <vector>
11
12 #include "ceed-magma-gemm-selector.h"
13
14 #include "tuning/indices.h"
15 #ifdef CEED_MAGMA_USE_HIP
16 #include "tuning/mi100.h"
17 #include "tuning/mi100_rtc.h"
18 #include "tuning/mi250x.h"
19 #include "tuning/mi250x_rtc.h"
20 #else
21 #include "tuning/a100.h"
22 #include "tuning/a100_rtc.h"
23 #include "tuning/h100_rtc.h"
24 #include "tuning/v100.h"
25 #include "tuning/v100_rtc.h"
26 #endif
27
// Define these macros to force a certain parameter when generating autotuning data offline
29 // #define CEED_AUTOTUNE_GEMM_SELECTOR_N_BATCH
30 // #define CEED_AUTOTUNE_GEMM_SELECTOR_USE_MAGMA
31 // #define CEED_AUTOTUNE_RTC_NB
32
33 ////////////////////////////////////////////////////////////////////////////////
34 #ifdef CEED_MAGMA_USE_HIP
gemm_selector_get_data(int gpu_arch,char precision,char trans_A)35 static inline auto gemm_selector_get_data(int gpu_arch, char precision, char trans_A) -> decltype(dgemm_nn_mi100) {
36 if (gpu_arch >= 910) {
37 // gfx90a or newer
38 return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_mi250x : sgemm_tn_mi250x) : ((trans_A == 'n') ? dgemm_nn_mi250x : dgemm_tn_mi250x);
39 } else {
40 // gfx908 or older
41 return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_mi100 : sgemm_tn_mi100) : ((trans_A == 'n') ? dgemm_nn_mi100 : dgemm_tn_mi100);
42 }
43 }
44 #else
gemm_selector_get_data(int gpu_arch,char precision,char trans_A)45 static inline auto gemm_selector_get_data(int gpu_arch, char precision, char trans_A) -> decltype(dgemm_nn_v100) {
46 if (gpu_arch >= 800) {
47 // sm80 or newer
48 return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_a100 : sgemm_tn_a100) : ((trans_A == 'n') ? dgemm_nn_a100 : dgemm_tn_a100);
49 } else {
50 // sm70 or older
51 return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_v100 : sgemm_tn_v100) : ((trans_A == 'n') ? dgemm_nn_v100 : dgemm_tn_v100);
52 }
53 }
54 #endif
55
56 ////////////////////////////////////////////////////////////////////////////////
gemm_selector(int gpu_arch,char precision,char trans_A,int m,int n,int k,int * n_batch,int * use_magma)57 void gemm_selector(int gpu_arch, char precision, char trans_A, int m, int n, int k, int *n_batch, int *use_magma) {
58 #if defined(CEED_AUTOTUNE_GEMM_SELECTOR_N_BATCH) && defined(CEED_AUTOTUNE_GEMM_SELECTOR_USE_MAGMA)
59 *n_batch = CEED_AUTOTUNE_GEMM_SELECTOR_N_BATCH;
60 *use_magma = CEED_AUTOTUNE_GEMM_SELECTOR_USE_MAGMA;
61 #else
62 const auto &data = gemm_selector_get_data(gpu_arch, precision, trans_A);
63 int ir = -1;
64 double norm = std::numeric_limits<double>::max();
65
66 for (size_t i = 0; i < data.size(); i++) {
67 const int &im = data[i][M_INDEX];
68 const int &in = data[i][N_INDEX];
69 const int &ik = data[i][K_INDEX];
70
71 double mdiff = (double)(im - m);
72 double ndiff = (double)(in - n);
73 double kdiff = (double)(ik - k);
74 double nrm = mdiff * mdiff + ndiff * ndiff + kdiff * kdiff;
75
76 if (nrm < norm) {
77 norm = nrm;
78 ir = i;
79 }
80
81 if (im == m && in == n && ik == k) {
82 // The input (m, n, k) exactly matches a record in `data`, no need to search further
83 break;
84 }
85 }
86
87 if (ir >= 0) {
88 // If the closest match indicates that n = n_batch, that means calling the regular non-batch GEMM.
89 // So n_batch is set to n instead of the 'n_batch' entry of the matching record.
90 int n_ = data[ir][N_INDEX];
91 int n_batch_ = data[ir][N_BATCH_INDEX];
92 *n_batch = (n_ == n_batch_) ? n : n_batch_;
93 *use_magma = data[ir][USE_MAGMA_INDEX];
94 } else {
95 *n_batch = n;
96 *use_magma = 0;
97 }
98 #endif
99 }
100
101 //////////////////////////////////////////////////////////////////////////////
102 #ifdef CEED_MAGMA_USE_HIP
nontensor_rtc_get_data(int gpu_arch,char trans_A)103 static inline auto nontensor_rtc_get_data(int gpu_arch, char trans_A) -> decltype(drtc_n_mi100) {
104 if (gpu_arch >= 910) {
105 // gfx90a or newer
106 return (trans_A == 'n') ? drtc_n_mi250x : drtc_t_mi250x;
107 } else {
108 // gfx908 or older
109 return (trans_A == 'n') ? drtc_n_mi100 : drtc_t_mi100;
110 }
111 }
112 #else
nontensor_rtc_get_data(int gpu_arch,char trans_A)113 static inline auto nontensor_rtc_get_data(int gpu_arch, char trans_A) -> decltype(drtc_n_v100) {
114 if (gpu_arch >= 900) {
115 // sm90 or newer
116 return (trans_A == 'n') ? drtc_n_h100 : drtc_t_h100;
117 } else if (gpu_arch >= 800) {
118 // sm80 or newer
119 return (trans_A == 'n') ? drtc_n_a100 : drtc_t_a100;
120 } else {
121 // sm70 or older
122 return (trans_A == 'n') ? drtc_n_v100 : drtc_t_v100;
123 }
124 }
125 #endif
126
127 ////////////////////////////////////////////////////////////////////////////////
nontensor_rtc_get_nb(int gpu_arch,char trans_A,int q_comp,int P,int Q,int N)128 CeedInt nontensor_rtc_get_nb(int gpu_arch, char trans_A, int q_comp, int P, int Q, int N) {
129 #ifdef CEED_AUTOTUNE_RTC_NB
130 return CEED_AUTOTUNE_RTC_NB;
131 #else
132 const auto &data = nontensor_rtc_get_data(gpu_arch, trans_A);
133 int ir = -1;
134 double norm = std::numeric_limits<double>::max();
135
136 for (size_t i = 0; i < data.size(); i++) {
137 // Only seach exact matches for q_comp
138 if (q_comp != data[i][Q_COMP_INDEX_RTC]) {
139 continue;
140 }
141
142 const int &iP = data[i][P_INDEX_RTC];
143 const int &iQ = data[i][Q_INDEX_RTC];
144 const int &iN = data[i][N_INDEX_RTC];
145
146 double Pdiff = (double)(iP - P);
147 double Qdiff = (double)(iQ - Q);
148 double Ndiff = (double)(iN - N);
149 double nrm = Pdiff * Pdiff + Qdiff * Qdiff + Ndiff * Ndiff;
150
151 if (nrm < norm) {
152 norm = nrm;
153 ir = i;
154 }
155
156 if (iP == P && iQ == Q && iN == N) {
157 // The input (P, Q, N) exactly matches a record in `data`, no need to search further
158 break;
159 }
160 }
161
162 return (ir >= 0) ? data[ir][NB_INDEX_RTC] : 1;
163 #endif
164 }
165