1 #include <stdio.h> 2 #include <sys/time.h> 3 4 #include <array> 5 #include <limits> 6 #include <vector> 7 8 #include "ceed-magma.h" 9 #include "tuning/indices.h" 10 #ifdef CEED_MAGMA_USE_HIP 11 #include "tuning/mi100.h" 12 #include "tuning/mi250x.h" 13 #include "tuning/mi250x_grad_rtc.h" 14 #include "tuning/mi250x_interp_rtc.h" 15 #else 16 #include "tuning/a100.h" 17 #include "tuning/a100_grad_rtc.h" 18 #include "tuning/a100_interp_rtc.h" 19 #include "tuning/v100.h" 20 #endif 21 22 //////////////////////////////////////////////////////////////////////////////// 23 static void *gemm_selector_get_data(int gpu_arch, char precision, char trans_A) { 24 // a default 25 #ifdef CEED_MAGMA_USE_HIP 26 void *data = (void *)&sgemm_nn_mi250x; 27 #else 28 void *data = (void *)&sgemm_nn_a100; 29 #endif 30 31 #ifdef CEED_MAGMA_USE_HIP 32 if (gpu_arch >= 910) { 33 // gfx90a or newer 34 data = (precision == 's') ? ((trans_A == 'n') ? (void *)&sgemm_nn_mi250x : (void *)&sgemm_tn_mi250x) 35 : ((trans_A == 'n') ? (void *)&dgemm_nn_mi250x : (void *)&dgemm_tn_mi250x); 36 } else { 37 // gfx908 or older 38 data = (precision == 's') ? ((trans_A == 'n') ? (void *)&sgemm_nn_mi100 : (void *)&sgemm_tn_mi100) 39 : ((trans_A == 'n') ? (void *)&dgemm_nn_mi100 : (void *)&dgemm_tn_mi100); 40 } 41 #else 42 if (gpu_arch >= 800) { 43 // sm80 or newer 44 data = (precision == 's') ? ((trans_A == 'n') ? (void *)&sgemm_nn_a100 : (void *)&sgemm_tn_a100) 45 : ((trans_A == 'n') ? (void *)&dgemm_nn_a100 : (void *)&dgemm_tn_a100); 46 } else { 47 // sm70 or older 48 data = (precision == 's') ? ((trans_A == 'n') ? (void *)&sgemm_nn_v100 : (void *)&sgemm_tn_v100) 49 : ((trans_A == 'n') ? (void *)&dgemm_nn_v100 : (void *)&dgemm_tn_v100); 50 } 51 #endif 52 53 return data; 54 } 55 56 //////////////////////////////////////////////////////////////////////////////// 57 void gemm_selector(int gpu_arch, char precision, char trans_A, int m, int n, int k, int *n_batch, int *use_magma) { 58 // defaults 59 *n_batch = n; 60 *use_magma = 0; 61 std::vector<std::array<int, RECORD_LENGTH> > *data = NULL; 62 data = (std::vector<std::array<int, RECORD_LENGTH> > *)gemm_selector_get_data(gpu_arch, precision, trans_A); 63 64 int ir = -1; 65 double norm = std::numeric_limits<double>::max(); 66 for (size_t i = 0; i < data->size(); i++) { 67 int im = (*data)[i][M_INDEX]; 68 int in = (*data)[i][N_INDEX]; 69 int ik = (*data)[i][K_INDEX]; 70 71 double mdiff = (double)(im - m); 72 double ndiff = (double)(in - n); 73 double kdiff = (double)(ik - k); 74 75 double nrm = sqrt(mdiff * mdiff + ndiff * ndiff + kdiff * kdiff); 76 77 if (nrm < norm) { 78 norm = nrm; 79 ir = i; 80 } 81 82 if (nrm == 0) { 83 // the input (m, n, k) exactly matches a record in `data` 84 // no need to search further 85 break; 86 } 87 } 88 89 if (ir >= 0) { 90 *use_magma = (*data)[ir][USE_MAGMA_INDEX]; 91 // if the closest match indicates that n = n_batch, 92 // that means calling the regular non-batch gemm. 93 // So n_batch is set to n instead of the 'n_batch' 94 // entry of the matching record 95 int n_ = (*data)[ir][N_INDEX]; 96 int n_batch_ = (*data)[ir][N_BATCH_INDEX]; 97 *n_batch = (n_ == n_batch_) ? n : n_batch_; 98 } 99 } 100 101 //////////////////////////////////////////////////////////////////////////////// 102 static void *nontensor_rtc_get_data(int gpu_arch, char precision, CeedEvalMode e_mode, CeedTransposeMode t_mode) { 103 // a default 104 #ifdef CEED_MAGMA_USE_HIP 105 void *data = (void *)&dinterp_n_mi250x; 106 #else 107 void *data = (void *)&dinterp_n_a100; 108 #endif 109 110 #ifdef CEED_MAGMA_USE_HIP 111 if (e_mode == CEED_EVAL_INTERP) { 112 data = (t_mode == CEED_TRANSPOSE) ? (void *)&dinterp_t_mi250x : (void *)&dinterp_n_mi250x; 113 } else if (e_mode == CEED_EVAL_GRAD) { 114 data = (t_mode == CEED_TRANSPOSE) ? (void *)&dgrad_t_mi250x : (void *)&dgrad_n_mi250x; 115 } 116 #else 117 if (e_mode == CEED_EVAL_INTERP) { 118 data = (t_mode == CEED_TRANSPOSE) ? (void *)&dinterp_t_a100 : (void *)&dinterp_n_a100; 119 } else if (e_mode == CEED_EVAL_GRAD) { 120 data = (t_mode == CEED_TRANSPOSE) ? (void *)&dgrad_t_a100 : (void *)&dgrad_n_a100; 121 } 122 #endif 123 124 return data; 125 } 126 127 //////////////////////////////////////////////////////////////////////////////// 128 CeedInt nontensor_rtc_get_nb(int gpu_arch, char precision, CeedEvalMode e_mode, CeedTransposeMode t_mode, int P_, int N, int Q_) { 129 CeedInt P = (t_mode == CEED_TRANSPOSE) ? P_ : Q_; 130 CeedInt Q = (t_mode == CEED_TRANSPOSE) ? Q_ : P_; 131 CeedInt NB = 1; 132 133 std::vector<std::array<int, RECORD_LENGTH_RTC> > *data = NULL; 134 data = (std::vector<std::array<int, RECORD_LENGTH_RTC> > *)nontensor_rtc_get_data(gpu_arch, precision, e_mode, t_mode); 135 136 int ir = -1; 137 double norm = std::numeric_limits<double>::max(); 138 for (size_t i = 0; i < data->size(); i++) { 139 int ip = (*data)[i][M_INDEX_RTC]; 140 int in = (*data)[i][N_INDEX_RTC]; 141 int iq = (*data)[i][K_INDEX_RTC]; 142 143 double pdiff = (double)(ip - P); 144 double ndiff = (double)(in - N); 145 double qdiff = (double)(iq - Q); 146 double nrm = sqrt(pdiff * pdiff + ndiff * ndiff + qdiff * qdiff); 147 148 if (nrm < norm) { 149 norm = nrm; 150 ir = i; 151 } 152 153 if (nrm == 0) { 154 // the input (m, n, k) exactly matches a record in `data` 155 // no need to search further 156 break; 157 } 158 } 159 160 if (ir >= 0) { 161 NB = (*data)[ir][NB_INDEX_RTC]; 162 } 163 164 return NB; 165 } 166