1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <array> 9 #include <limits> 10 #include <vector> 11 12 #include "ceed-magma-gemm-selector.h" 13 14 #include "tuning/indices.h" 15 #ifdef CEED_MAGMA_USE_HIP 16 #include "tuning/mi100.h" 17 #include "tuning/mi250x.h" 18 #include "tuning/mi250x_grad_rtc.h" 19 #include "tuning/mi250x_interp_rtc.h" 20 #else 21 #include "tuning/a100.h" 22 #include "tuning/a100_grad_rtc.h" 23 #include "tuning/a100_interp_rtc.h" 24 #include "tuning/v100.h" 25 #endif 26 27 // These definitions to force a certain parameter when generating autotuning data offline 28 // #define CEED_AUTOTUNE_GEMM_SELECTOR_N_BATCH 1 29 // #define CEED_AUTOTUNE_GEMM_SELECTOR_USE_MAGMA true 30 // #define CEED_AUTOTUNE_RTC_NB 1 31 32 //////////////////////////////////////////////////////////////////////////////// 33 #ifdef CEED_MAGMA_USE_HIP 34 static inline auto gemm_selector_get_data(int gpu_arch, char precision, char trans_A) -> decltype(dgemm_nn_mi250x) { 35 if (gpu_arch >= 910) { 36 // gfx90a or newer 37 return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_mi250x : sgemm_tn_mi250x) : ((trans_A == 'n') ? dgemm_nn_mi250x : dgemm_tn_mi250x); 38 } else { 39 // gfx908 or older 40 return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_mi100 : sgemm_tn_mi100) : ((trans_A == 'n') ? dgemm_nn_mi100 : dgemm_tn_mi100); 41 } 42 } 43 #else 44 static inline auto gemm_selector_get_data(int gpu_arch, char precision, char trans_A) -> decltype(dgemm_nn_a100) { 45 if (gpu_arch >= 800) { 46 // sm80 or newer 47 return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_a100 : sgemm_tn_a100) : ((trans_A == 'n') ? dgemm_nn_a100 : dgemm_tn_a100); 48 } else { 49 // sm70 or older 50 return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_v100 : sgemm_tn_v100) : ((trans_A == 'n') ? dgemm_nn_v100 : dgemm_tn_v100); 51 } 52 } 53 #endif 54 55 //////////////////////////////////////////////////////////////////////////////// 56 void gemm_selector(int gpu_arch, char precision, char trans_A, int m, int n, int k, int *n_batch, int *use_magma) { 57 #if defined(CEED_AUTOTUNE_GEMM_SELECTOR_N_BATCH) && defined(CEED_AUTOTUNE_GEMM_SELECTOR_USE_MAGMA) 58 *n_batch = CEED_AUTOTUNE_GEMM_SELECTOR_N_BATCH; 59 *use_magma = CEED_AUTOTUNE_GEMM_SELECTOR_USE_MAGMA; 60 #else 61 const auto &data = gemm_selector_get_data(gpu_arch, precision, trans_A); 62 int ir = -1; 63 double norm = std::numeric_limits<double>::max(); 64 65 for (size_t i = 0; i < data.size(); i++) { 66 const int &im = data[i][M_INDEX]; 67 const int &in = data[i][N_INDEX]; 68 const int &ik = data[i][K_INDEX]; 69 70 double mdiff = (double)(im - m); 71 double ndiff = (double)(in - n); 72 double kdiff = (double)(ik - k); 73 double nrm = mdiff * mdiff + ndiff * ndiff + kdiff * kdiff; 74 75 if (nrm < norm) { 76 norm = nrm; 77 ir = i; 78 } 79 80 if (im == m && in == n && ik == k) { 81 // The input (m, n, k) exactly matches a record in `data`, no need to search further 82 break; 83 } 84 } 85 86 if (ir >= 0) { 87 // If the closest match indicates that n = n_batch, that means calling the regular non-batch GEMM. 88 // So n_batch is set to n instead of the 'n_batch' entry of the matching record. 89 int n_ = data[ir][N_INDEX]; 90 int n_batch_ = data[ir][N_BATCH_INDEX]; 91 *n_batch = (n_ == n_batch_) ? n : n_batch_; 92 *use_magma = data[ir][USE_MAGMA_INDEX]; 93 } else { 94 *n_batch = n; 95 *use_magma = 0; 96 } 97 #endif 98 } 99 100 ////////////////////////////////////////////////////////////////////////////// 101 #ifdef CEED_MAGMA_USE_HIP 102 static inline auto nontensor_rtc_get_data(int gpu_arch, char trans_A, int q_comp) -> decltype(dinterp_n_mi250x) { 103 if (q_comp == 1) { 104 return (trans_A == 'n') ? dinterp_n_mi250x : dinterp_t_mi250x; 105 } else { 106 return (trans_A == 'n') ? dgrad_n_mi250x : dgrad_t_mi250x; 107 } 108 } 109 #else 110 static inline auto nontensor_rtc_get_data(int gpu_arch, char trans_A, int q_comp) -> decltype(dinterp_n_a100) { 111 if (q_comp == 1) { 112 return (trans_A == 'n') ? dinterp_n_a100 : dinterp_t_a100; 113 } else { 114 return (trans_A == 'n') ? dgrad_n_a100 : dgrad_t_a100; 115 } 116 } 117 #endif 118 119 //////////////////////////////////////////////////////////////////////////////// 120 CeedInt nontensor_rtc_get_nb(int gpu_arch, char trans_A, int q_comp, int P, int Q, int n) { 121 #ifdef CEED_AUTOTUNE_RTC_NB 122 return CEED_AUTOTUNE_RTC_NB; 123 #else 124 const auto &data = nontensor_rtc_get_data(gpu_arch, trans_A, q_comp); 125 int ir = -1; 126 double norm = std::numeric_limits<double>::max(); 127 CeedInt m = (trans_A == 'n') ? Q : P; 128 CeedInt k = (trans_A == 'n') ? P : Q; 129 130 for (size_t i = 0; i < data.size(); i++) { 131 const int &im = data[i][M_INDEX_RTC]; 132 const int &in = data[i][N_INDEX_RTC]; 133 const int &ik = data[i][K_INDEX_RTC]; 134 135 double mdiff = (double)(im - m); 136 double ndiff = (double)(in - n); 137 double kdiff = (double)(ik - k); 138 double nrm = mdiff * mdiff + ndiff * ndiff + kdiff * kdiff; 139 140 if (nrm < norm) { 141 norm = nrm; 142 ir = i; 143 } 144 145 if (im == m && in == n && ik == k) { 146 // The input (m, n, k) exactly matches a record in `data`, no need to search further 147 break; 148 } 149 } 150 151 return (ir >= 0) ? data[ir][NB_INDEX_RTC] : 1; 152 #endif 153 } 154