1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include "ceed-magma-gemm-selector.h" 9 10 #include <array> 11 #include <limits> 12 #include <vector> 13 14 #include "tuning/indices.h" 15 #ifdef CEED_MAGMA_USE_HIP 16 #include "tuning/mi100.h" 17 #include "tuning/mi250x.h" 18 #include "tuning/mi250x_grad_rtc.h" 19 #include "tuning/mi250x_interp_rtc.h" 20 #else 21 #include "tuning/a100.h" 22 #include "tuning/a100_grad_rtc.h" 23 #include "tuning/a100_interp_rtc.h" 24 #include "tuning/v100.h" 25 #endif 26 27 //////////////////////////////////////////////////////////////////////////////// 28 #ifdef CEED_MAGMA_USE_HIP 29 static inline auto gemm_selector_get_data(int gpu_arch, char precision, char trans_A) -> decltype(dgemm_nn_mi250x) { 30 if (gpu_arch >= 910) { 31 // gfx90a or newer 32 return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_mi250x : sgemm_tn_mi250x) : ((trans_A == 'n') ? dgemm_nn_mi250x : dgemm_tn_mi250x); 33 } else { 34 // gfx908 or older 35 return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_mi100 : sgemm_tn_mi100) : ((trans_A == 'n') ? dgemm_nn_mi100 : dgemm_tn_mi100); 36 } 37 } 38 #else 39 static inline auto gemm_selector_get_data(int gpu_arch, char precision, char trans_A) -> decltype(dgemm_nn_a100) { 40 if (gpu_arch >= 800) { 41 // sm80 or newer 42 return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_a100 : sgemm_tn_a100) : ((trans_A == 'n') ? dgemm_nn_a100 : dgemm_tn_a100); 43 } else { 44 // sm70 or older 45 return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_v100 : sgemm_tn_v100) : ((trans_A == 'n') ? dgemm_nn_v100 : dgemm_tn_v100); 46 } 47 } 48 #endif 49 50 //////////////////////////////////////////////////////////////////////////////// 51 void gemm_selector(int gpu_arch, char precision, char trans_A, int m, int n, int k, int *n_batch, int *use_magma) { 52 const auto &data = gemm_selector_get_data(gpu_arch, precision, trans_A); 53 int ir = -1; 54 double norm = std::numeric_limits<double>::max(); 55 56 for (size_t i = 0; i < data.size(); i++) { 57 const int &im = data[i][M_INDEX]; 58 const int &in = data[i][N_INDEX]; 59 const int &ik = data[i][K_INDEX]; 60 61 double mdiff = (double)(im - m); 62 double ndiff = (double)(in - n); 63 double kdiff = (double)(ik - k); 64 double nrm = mdiff * mdiff + ndiff * ndiff + kdiff * kdiff; 65 66 if (nrm < norm) { 67 norm = nrm; 68 ir = i; 69 } 70 71 if (im == m && in == n && ik == k) { 72 // The input (m, n, k) exactly matches a record in `data`, no need to search further 73 break; 74 } 75 } 76 77 if (ir >= 0) { 78 // If the closest match indicates that n = n_batch, that means calling the regular non-batch GEMM. 79 // So n_batch is set to n instead of the 'n_batch' entry of the matching record. 80 int n_ = data[ir][N_INDEX]; 81 int n_batch_ = data[ir][N_BATCH_INDEX]; 82 *n_batch = (n_ == n_batch_) ? n : n_batch_; 83 *use_magma = data[ir][USE_MAGMA_INDEX]; 84 } else { 85 *n_batch = n; 86 *use_magma = 0; 87 } 88 } 89 90 ////////////////////////////////////////////////////////////////////////////// 91 #ifdef CEED_MAGMA_USE_HIP 92 static inline auto nontensor_rtc_get_data(int gpu_arch, char trans_A, int q_comp) -> decltype(dinterp_n_mi250x) { 93 if (q_comp == 1) { 94 return (trans_A == 'n') ? dinterp_n_mi250x : dinterp_t_mi250x; 95 } else { 96 return (trans_A == 'n') ? dgrad_n_mi250x : dgrad_t_mi250x; 97 } 98 } 99 #else 100 static inline auto nontensor_rtc_get_data(int gpu_arch, char trans_A, int q_comp) -> decltype(dinterp_n_a100) { 101 if (q_comp == 1) { 102 return (trans_A == 'n') ? dinterp_n_a100 : dinterp_t_a100; 103 } else { 104 return (trans_A == 'n') ? dgrad_n_a100 : dgrad_t_a100; 105 } 106 } 107 #endif 108 109 //////////////////////////////////////////////////////////////////////////////// 110 CeedInt nontensor_rtc_get_nb(int gpu_arch, char trans_A, int q_comp, int P, int Q, int n) { 111 const auto &data = nontensor_rtc_get_data(gpu_arch, trans_A, q_comp); 112 int ir = -1; 113 double norm = std::numeric_limits<double>::max(); 114 CeedInt m = (trans_A == 'n') ? Q : P; 115 CeedInt k = (trans_A == 'n') ? P : Q; 116 117 for (size_t i = 0; i < data.size(); i++) { 118 const int &im = data[i][M_INDEX_RTC]; 119 const int &in = data[i][N_INDEX_RTC]; 120 const int &ik = data[i][K_INDEX_RTC]; 121 122 double mdiff = (double)(im - m); 123 double ndiff = (double)(in - n); 124 double kdiff = (double)(ik - k); 125 double nrm = mdiff * mdiff + ndiff * ndiff + kdiff * kdiff; 126 127 if (nrm < norm) { 128 norm = nrm; 129 ir = i; 130 } 131 132 if (im == m && in == n && ik == k) { 133 // The input (m, n, k) exactly matches a record in `data`, no need to search further 134 break; 135 } 136 } 137 138 return (ir >= 0) ? data[ir][NB_INDEX_RTC] : 1; 139 } 140