1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <array> 9 #include <limits> 10 #include <vector> 11 12 #include "ceed-magma-gemm-selector.h" 13 14 #include "tuning/indices.h" 15 #ifdef CEED_MAGMA_USE_HIP 16 #include "tuning/mi100.h" 17 #include "tuning/mi100_rtc.h" 18 #include "tuning/mi250x.h" 19 #include "tuning/mi250x_rtc.h" 20 #else 21 #include "tuning/a100.h" 22 #include "tuning/a100_rtc.h" 23 #include "tuning/h100_rtc.h" 24 #include "tuning/v100.h" 25 #include "tuning/v100_rtc.h" 26 #endif 27 28 // These definitions to force a certain parameter when generating autotuning data offline 29 // #define CEED_AUTOTUNE_GEMM_SELECTOR_N_BATCH 30 // #define CEED_AUTOTUNE_GEMM_SELECTOR_USE_MAGMA 31 // #define CEED_AUTOTUNE_RTC_NB 32 33 //////////////////////////////////////////////////////////////////////////////// 34 #ifdef CEED_MAGMA_USE_HIP 35 static inline auto gemm_selector_get_data(int gpu_arch, char precision, char trans_A) -> decltype(dgemm_nn_mi100) { 36 if (gpu_arch >= 910) { 37 // gfx90a or newer 38 return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_mi250x : sgemm_tn_mi250x) : ((trans_A == 'n') ? dgemm_nn_mi250x : dgemm_tn_mi250x); 39 } else { 40 // gfx908 or older 41 return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_mi100 : sgemm_tn_mi100) : ((trans_A == 'n') ? dgemm_nn_mi100 : dgemm_tn_mi100); 42 } 43 } 44 #else 45 static inline auto gemm_selector_get_data(int gpu_arch, char precision, char trans_A) -> decltype(dgemm_nn_v100) { 46 if (gpu_arch >= 800) { 47 // sm80 or newer 48 return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_a100 : sgemm_tn_a100) : ((trans_A == 'n') ? dgemm_nn_a100 : dgemm_tn_a100); 49 } else { 50 // sm70 or older 51 return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_v100 : sgemm_tn_v100) : ((trans_A == 'n') ? dgemm_nn_v100 : dgemm_tn_v100); 52 } 53 } 54 #endif 55 56 //////////////////////////////////////////////////////////////////////////////// 57 void gemm_selector(int gpu_arch, char precision, char trans_A, int m, int n, int k, int *n_batch, int *use_magma) { 58 #if defined(CEED_AUTOTUNE_GEMM_SELECTOR_N_BATCH) && defined(CEED_AUTOTUNE_GEMM_SELECTOR_USE_MAGMA) 59 *n_batch = CEED_AUTOTUNE_GEMM_SELECTOR_N_BATCH; 60 *use_magma = CEED_AUTOTUNE_GEMM_SELECTOR_USE_MAGMA; 61 #else 62 const auto &data = gemm_selector_get_data(gpu_arch, precision, trans_A); 63 int ir = -1; 64 double norm = std::numeric_limits<double>::max(); 65 66 for (size_t i = 0; i < data.size(); i++) { 67 const int &im = data[i][M_INDEX]; 68 const int &in = data[i][N_INDEX]; 69 const int &ik = data[i][K_INDEX]; 70 71 double mdiff = (double)(im - m); 72 double ndiff = (double)(in - n); 73 double kdiff = (double)(ik - k); 74 double nrm = mdiff * mdiff + ndiff * ndiff + kdiff * kdiff; 75 76 if (nrm < norm) { 77 norm = nrm; 78 ir = i; 79 } 80 81 if (im == m && in == n && ik == k) { 82 // The input (m, n, k) exactly matches a record in `data`, no need to search further 83 break; 84 } 85 } 86 87 if (ir >= 0) { 88 // If the closest match indicates that n = n_batch, that means calling the regular non-batch GEMM. 89 // So n_batch is set to n instead of the 'n_batch' entry of the matching record. 90 int n_ = data[ir][N_INDEX]; 91 int n_batch_ = data[ir][N_BATCH_INDEX]; 92 *n_batch = (n_ == n_batch_) ? n : n_batch_; 93 *use_magma = data[ir][USE_MAGMA_INDEX]; 94 } else { 95 *n_batch = n; 96 *use_magma = 0; 97 } 98 #endif 99 } 100 101 ////////////////////////////////////////////////////////////////////////////// 102 #ifdef CEED_MAGMA_USE_HIP 103 static inline auto nontensor_rtc_get_data(int gpu_arch, char trans_A) -> decltype(drtc_n_mi100) { 104 if (gpu_arch >= 910) { 105 // gfx90a or newer 106 return (trans_A == 'n') ? drtc_n_mi250x : drtc_t_mi250x; 107 } else { 108 // gfx908 or older 109 return (trans_A == 'n') ? drtc_n_mi100 : drtc_t_mi100; 110 } 111 } 112 #else 113 static inline auto nontensor_rtc_get_data(int gpu_arch, char trans_A) -> decltype(drtc_n_v100) { 114 if (gpu_arch >= 900) { 115 // sm90 or newer 116 return (trans_A == 'n') ? drtc_n_h100 : drtc_t_h100; 117 } else if (gpu_arch >= 800) { 118 // sm80 or newer 119 return (trans_A == 'n') ? drtc_n_a100 : drtc_t_a100; 120 } else { 121 // sm70 or older 122 return (trans_A == 'n') ? drtc_n_v100 : drtc_t_v100; 123 } 124 } 125 #endif 126 127 //////////////////////////////////////////////////////////////////////////////// 128 CeedInt nontensor_rtc_get_nb(int gpu_arch, char trans_A, int q_comp, int P, int Q, int N) { 129 #ifdef CEED_AUTOTUNE_RTC_NB 130 return CEED_AUTOTUNE_RTC_NB; 131 #else 132 const auto &data = nontensor_rtc_get_data(gpu_arch, trans_A); 133 int ir = -1; 134 double norm = std::numeric_limits<double>::max(); 135 136 for (size_t i = 0; i < data.size(); i++) { 137 // Only seach exact matches for q_comp 138 if (q_comp != data[i][Q_COMP_INDEX_RTC]) { 139 continue; 140 } 141 142 const int &iP = data[i][P_INDEX_RTC]; 143 const int &iQ = data[i][Q_INDEX_RTC]; 144 const int &iN = data[i][N_INDEX_RTC]; 145 146 double Pdiff = (double)(iP - P); 147 double Qdiff = (double)(iQ - Q); 148 double Ndiff = (double)(iN - N); 149 double nrm = Pdiff * Pdiff + Qdiff * Qdiff + Ndiff * Ndiff; 150 151 if (nrm < norm) { 152 norm = nrm; 153 ir = i; 154 } 155 156 if (iP == P && iQ == Q && iN == N) { 157 // The input (P, Q, N) exactly matches a record in `data`, no need to search further 158 break; 159 } 160 } 161 162 return (ir >= 0) ? data[ir][NB_INDEX_RTC] : 1; 163 #endif 164 } 165