1*f80f4a74SSebastian Grimberg #include <stdio.h> 2*f80f4a74SSebastian Grimberg #include <sys/time.h> 3*f80f4a74SSebastian Grimberg 4*f80f4a74SSebastian Grimberg #include <array> 5*f80f4a74SSebastian Grimberg #include <limits> 6*f80f4a74SSebastian Grimberg #include <vector> 7*f80f4a74SSebastian Grimberg 8*f80f4a74SSebastian Grimberg #include "ceed-magma.h" 9*f80f4a74SSebastian Grimberg #include "tuning/indices.h" 10*f80f4a74SSebastian Grimberg #ifdef CEED_MAGMA_USE_HIP 11*f80f4a74SSebastian Grimberg #include "tuning/mi100.h" 12*f80f4a74SSebastian Grimberg #include "tuning/mi250x.h" 13*f80f4a74SSebastian Grimberg #include "tuning/mi250x_grad_rtc.h" 14*f80f4a74SSebastian Grimberg #include "tuning/mi250x_interp_rtc.h" 15*f80f4a74SSebastian Grimberg #else 16*f80f4a74SSebastian Grimberg #include "tuning/a100.h" 17*f80f4a74SSebastian Grimberg #include "tuning/a100_grad_rtc.h" 18*f80f4a74SSebastian Grimberg #include "tuning/a100_interp_rtc.h" 19*f80f4a74SSebastian Grimberg #include "tuning/v100.h" 20*f80f4a74SSebastian Grimberg #endif 21*f80f4a74SSebastian Grimberg 22*f80f4a74SSebastian Grimberg //////////////////////////////////////////////////////////////////////////////// 23*f80f4a74SSebastian Grimberg static void *gemm_selector_get_data(int gpu_arch, char precision, char trans_A) { 24*f80f4a74SSebastian Grimberg // a default 25*f80f4a74SSebastian Grimberg #ifdef CEED_MAGMA_USE_HIP 26*f80f4a74SSebastian Grimberg void *data = (void *)&sgemm_nn_mi250x; 27*f80f4a74SSebastian Grimberg #else 28*f80f4a74SSebastian Grimberg void *data = (void *)&sgemm_nn_a100; 29*f80f4a74SSebastian Grimberg #endif 30*f80f4a74SSebastian Grimberg 31*f80f4a74SSebastian Grimberg #ifdef CEED_MAGMA_USE_HIP 32*f80f4a74SSebastian Grimberg if (gpu_arch >= 910) { 33*f80f4a74SSebastian Grimberg // gfx90a or newer 34*f80f4a74SSebastian Grimberg data = (precision == 's') ? ((trans_A == 'n') ? (void *)&sgemm_nn_mi250x : (void *)&sgemm_tn_mi250x) 35*f80f4a74SSebastian Grimberg : ((trans_A == 'n') ? (void *)&dgemm_nn_mi250x : (void *)&dgemm_tn_mi250x); 36*f80f4a74SSebastian Grimberg } else { 37*f80f4a74SSebastian Grimberg // gfx908 or older 38*f80f4a74SSebastian Grimberg data = (precision == 's') ? ((trans_A == 'n') ? (void *)&sgemm_nn_mi100 : (void *)&sgemm_tn_mi100) 39*f80f4a74SSebastian Grimberg : ((trans_A == 'n') ? (void *)&dgemm_nn_mi100 : (void *)&dgemm_tn_mi100); 40*f80f4a74SSebastian Grimberg } 41*f80f4a74SSebastian Grimberg #else 42*f80f4a74SSebastian Grimberg if (gpu_arch >= 800) { 43*f80f4a74SSebastian Grimberg // sm80 or newer 44*f80f4a74SSebastian Grimberg data = (precision == 's') ? ((trans_A == 'n') ? (void *)&sgemm_nn_a100 : (void *)&sgemm_tn_a100) 45*f80f4a74SSebastian Grimberg : ((trans_A == 'n') ? (void *)&dgemm_nn_a100 : (void *)&dgemm_tn_a100); 46*f80f4a74SSebastian Grimberg } else { 47*f80f4a74SSebastian Grimberg // sm70 or older 48*f80f4a74SSebastian Grimberg data = (precision == 's') ? ((trans_A == 'n') ? (void *)&sgemm_nn_v100 : (void *)&sgemm_tn_v100) 49*f80f4a74SSebastian Grimberg : ((trans_A == 'n') ? (void *)&dgemm_nn_v100 : (void *)&dgemm_tn_v100); 50*f80f4a74SSebastian Grimberg } 51*f80f4a74SSebastian Grimberg #endif 52*f80f4a74SSebastian Grimberg 53*f80f4a74SSebastian Grimberg return data; 54*f80f4a74SSebastian Grimberg } 55*f80f4a74SSebastian Grimberg 56*f80f4a74SSebastian Grimberg //////////////////////////////////////////////////////////////////////////////// 57*f80f4a74SSebastian Grimberg void gemm_selector(int gpu_arch, char precision, char trans_A, int m, int n, int k, int *n_batch, int *use_magma) { 58*f80f4a74SSebastian Grimberg // defaults 59*f80f4a74SSebastian Grimberg *n_batch = n; 60*f80f4a74SSebastian Grimberg *use_magma = 0; 61*f80f4a74SSebastian Grimberg std::vector<std::array<int, RECORD_LENGTH> > *data = NULL; 62*f80f4a74SSebastian Grimberg data = (std::vector<std::array<int, RECORD_LENGTH> > *)gemm_selector_get_data(gpu_arch, precision, trans_A); 63*f80f4a74SSebastian Grimberg 64*f80f4a74SSebastian Grimberg int ir = -1; 65*f80f4a74SSebastian Grimberg double norm = std::numeric_limits<double>::max(); 66*f80f4a74SSebastian Grimberg for (size_t i = 0; i < data->size(); i++) { 67*f80f4a74SSebastian Grimberg int im = (*data)[i][M_INDEX]; 68*f80f4a74SSebastian Grimberg int in = (*data)[i][N_INDEX]; 69*f80f4a74SSebastian Grimberg int ik = (*data)[i][K_INDEX]; 70*f80f4a74SSebastian Grimberg 71*f80f4a74SSebastian Grimberg double mdiff = (double)(im - m); 72*f80f4a74SSebastian Grimberg double ndiff = (double)(in - n); 73*f80f4a74SSebastian Grimberg double kdiff = (double)(ik - k); 74*f80f4a74SSebastian Grimberg 75*f80f4a74SSebastian Grimberg double nrm = sqrt(mdiff * mdiff + ndiff * ndiff + kdiff * kdiff); 76*f80f4a74SSebastian Grimberg 77*f80f4a74SSebastian Grimberg if (nrm < norm) { 78*f80f4a74SSebastian Grimberg norm = nrm; 79*f80f4a74SSebastian Grimberg ir = i; 80*f80f4a74SSebastian Grimberg } 81*f80f4a74SSebastian Grimberg 82*f80f4a74SSebastian Grimberg if (nrm == 0) { 83*f80f4a74SSebastian Grimberg // the input (m, n, k) exactly matches a record in `data` 84*f80f4a74SSebastian Grimberg // no need to search further 85*f80f4a74SSebastian Grimberg break; 86*f80f4a74SSebastian Grimberg } 87*f80f4a74SSebastian Grimberg } 88*f80f4a74SSebastian Grimberg 89*f80f4a74SSebastian Grimberg if (ir >= 0) { 90*f80f4a74SSebastian Grimberg *use_magma = (*data)[ir][USE_MAGMA_INDEX]; 91*f80f4a74SSebastian Grimberg // if the closest match indicates that n = n_batch, 92*f80f4a74SSebastian Grimberg // that means calling the regular non-batch gemm. 93*f80f4a74SSebastian Grimberg // So n_batch is set to n instead of the 'n_batch' 94*f80f4a74SSebastian Grimberg // entry of the matching record 95*f80f4a74SSebastian Grimberg int n_ = (*data)[ir][N_INDEX]; 96*f80f4a74SSebastian Grimberg int n_batch_ = (*data)[ir][N_BATCH_INDEX]; 97*f80f4a74SSebastian Grimberg *n_batch = (n_ == n_batch_) ? n : n_batch_; 98*f80f4a74SSebastian Grimberg } 99*f80f4a74SSebastian Grimberg } 100*f80f4a74SSebastian Grimberg 101*f80f4a74SSebastian Grimberg //////////////////////////////////////////////////////////////////////////////// 102*f80f4a74SSebastian Grimberg static void *nontensor_rtc_get_data(int gpu_arch, char precision, CeedEvalMode e_mode, CeedTransposeMode t_mode) { 103*f80f4a74SSebastian Grimberg // a default 104*f80f4a74SSebastian Grimberg #ifdef CEED_MAGMA_USE_HIP 105*f80f4a74SSebastian Grimberg void *data = (void *)&dinterp_n_mi250x; 106*f80f4a74SSebastian Grimberg #else 107*f80f4a74SSebastian Grimberg void *data = (void *)&dinterp_n_a100; 108*f80f4a74SSebastian Grimberg #endif 109*f80f4a74SSebastian Grimberg 110*f80f4a74SSebastian Grimberg #ifdef CEED_MAGMA_USE_HIP 111*f80f4a74SSebastian Grimberg if (e_mode == CEED_EVAL_INTERP) { 112*f80f4a74SSebastian Grimberg data = (t_mode == CEED_TRANSPOSE) ? (void *)&dinterp_t_mi250x : (void *)&dinterp_n_mi250x; 113*f80f4a74SSebastian Grimberg } else if (e_mode == CEED_EVAL_GRAD) { 114*f80f4a74SSebastian Grimberg data = (t_mode == CEED_TRANSPOSE) ? (void *)&dgrad_t_mi250x : (void *)&dgrad_n_mi250x; 115*f80f4a74SSebastian Grimberg } 116*f80f4a74SSebastian Grimberg #else 117*f80f4a74SSebastian Grimberg if (e_mode == CEED_EVAL_INTERP) { 118*f80f4a74SSebastian Grimberg data = (t_mode == CEED_TRANSPOSE) ? (void *)&dinterp_t_a100 : (void *)&dinterp_n_a100; 119*f80f4a74SSebastian Grimberg } else if (e_mode == CEED_EVAL_GRAD) { 120*f80f4a74SSebastian Grimberg data = (t_mode == CEED_TRANSPOSE) ? (void *)&dgrad_t_a100 : (void *)&dgrad_n_a100; 121*f80f4a74SSebastian Grimberg } 122*f80f4a74SSebastian Grimberg #endif 123*f80f4a74SSebastian Grimberg 124*f80f4a74SSebastian Grimberg return data; 125*f80f4a74SSebastian Grimberg } 126*f80f4a74SSebastian Grimberg 127*f80f4a74SSebastian Grimberg //////////////////////////////////////////////////////////////////////////////// 128*f80f4a74SSebastian Grimberg CeedInt nontensor_rtc_get_nb(int gpu_arch, char precision, CeedEvalMode e_mode, CeedTransposeMode t_mode, int P_, int N, int Q_) { 129*f80f4a74SSebastian Grimberg CeedInt P = (t_mode == CEED_TRANSPOSE) ? P_ : Q_; 130*f80f4a74SSebastian Grimberg CeedInt Q = (t_mode == CEED_TRANSPOSE) ? Q_ : P_; 131*f80f4a74SSebastian Grimberg CeedInt NB = 1; 132*f80f4a74SSebastian Grimberg 133*f80f4a74SSebastian Grimberg std::vector<std::array<int, RECORD_LENGTH_RTC> > *data = NULL; 134*f80f4a74SSebastian Grimberg data = (std::vector<std::array<int, RECORD_LENGTH_RTC> > *)nontensor_rtc_get_data(gpu_arch, precision, e_mode, t_mode); 135*f80f4a74SSebastian Grimberg 136*f80f4a74SSebastian Grimberg int ir = -1; 137*f80f4a74SSebastian Grimberg double norm = std::numeric_limits<double>::max(); 138*f80f4a74SSebastian Grimberg for (size_t i = 0; i < data->size(); i++) { 139*f80f4a74SSebastian Grimberg int ip = (*data)[i][M_INDEX_RTC]; 140*f80f4a74SSebastian Grimberg int in = (*data)[i][N_INDEX_RTC]; 141*f80f4a74SSebastian Grimberg int iq = (*data)[i][K_INDEX_RTC]; 142*f80f4a74SSebastian Grimberg 143*f80f4a74SSebastian Grimberg double pdiff = (double)(ip - P); 144*f80f4a74SSebastian Grimberg double ndiff = (double)(in - N); 145*f80f4a74SSebastian Grimberg double qdiff = (double)(iq - Q); 146*f80f4a74SSebastian Grimberg double nrm = sqrt(pdiff * pdiff + ndiff * ndiff + qdiff * qdiff); 147*f80f4a74SSebastian Grimberg 148*f80f4a74SSebastian Grimberg if (nrm < norm) { 149*f80f4a74SSebastian Grimberg norm = nrm; 150*f80f4a74SSebastian Grimberg ir = i; 151*f80f4a74SSebastian Grimberg } 152*f80f4a74SSebastian Grimberg 153*f80f4a74SSebastian Grimberg if (nrm == 0) { 154*f80f4a74SSebastian Grimberg // the input (m, n, k) exactly matches a record in `data` 155*f80f4a74SSebastian Grimberg // no need to search further 156*f80f4a74SSebastian Grimberg break; 157*f80f4a74SSebastian Grimberg } 158*f80f4a74SSebastian Grimberg } 159*f80f4a74SSebastian Grimberg 160*f80f4a74SSebastian Grimberg if (ir >= 0) { 161*f80f4a74SSebastian Grimberg NB = (*data)[ir][NB_INDEX_RTC]; 162*f80f4a74SSebastian Grimberg } 163*f80f4a74SSebastian Grimberg 164*f80f4a74SSebastian Grimberg return NB; 165*f80f4a74SSebastian Grimberg } 166