1*940a72f1SSebastian Grimberg // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2*940a72f1SSebastian Grimberg // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3*940a72f1SSebastian Grimberg // 4*940a72f1SSebastian Grimberg // SPDX-License-Identifier: BSD-2-Clause 5*940a72f1SSebastian Grimberg // 6*940a72f1SSebastian Grimberg // This file is part of CEED: http://github.com/ceed 7*940a72f1SSebastian Grimberg 8*940a72f1SSebastian Grimberg #include "ceed-magma-gemm-selector.h" 9f80f4a74SSebastian Grimberg 10f80f4a74SSebastian Grimberg #include <array> 11f80f4a74SSebastian Grimberg #include <limits> 12f80f4a74SSebastian Grimberg #include <vector> 13f80f4a74SSebastian Grimberg 14f80f4a74SSebastian Grimberg #include "tuning/indices.h" 15f80f4a74SSebastian Grimberg #ifdef CEED_MAGMA_USE_HIP 16f80f4a74SSebastian Grimberg #include "tuning/mi100.h" 17f80f4a74SSebastian Grimberg #include "tuning/mi250x.h" 18f80f4a74SSebastian Grimberg #include "tuning/mi250x_grad_rtc.h" 19f80f4a74SSebastian Grimberg #include "tuning/mi250x_interp_rtc.h" 20f80f4a74SSebastian Grimberg #else 21f80f4a74SSebastian Grimberg #include "tuning/a100.h" 22f80f4a74SSebastian Grimberg #include "tuning/a100_grad_rtc.h" 23f80f4a74SSebastian Grimberg #include "tuning/a100_interp_rtc.h" 24f80f4a74SSebastian Grimberg #include "tuning/v100.h" 25f80f4a74SSebastian Grimberg #endif 26f80f4a74SSebastian Grimberg 27f80f4a74SSebastian Grimberg //////////////////////////////////////////////////////////////////////////////// 28f80f4a74SSebastian Grimberg #ifdef CEED_MAGMA_USE_HIP 29*940a72f1SSebastian Grimberg static inline auto gemm_selector_get_data(int gpu_arch, char precision, char trans_A) -> decltype(dgemm_nn_mi250x) { 30f80f4a74SSebastian Grimberg if (gpu_arch >= 910) { 31f80f4a74SSebastian Grimberg // gfx90a or newer 32*940a72f1SSebastian Grimberg return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_mi250x : sgemm_tn_mi250x) : ((trans_A == 'n') ? dgemm_nn_mi250x : dgemm_tn_mi250x); 33f80f4a74SSebastian Grimberg } else { 34f80f4a74SSebastian Grimberg // gfx908 or older 35*940a72f1SSebastian Grimberg return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_mi100 : sgemm_tn_mi100) : ((trans_A == 'n') ? dgemm_nn_mi100 : dgemm_tn_mi100); 36*940a72f1SSebastian Grimberg } 37f80f4a74SSebastian Grimberg } 38f80f4a74SSebastian Grimberg #else 39*940a72f1SSebastian Grimberg static inline auto gemm_selector_get_data(int gpu_arch, char precision, char trans_A) -> decltype(dgemm_nn_a100) { 40f80f4a74SSebastian Grimberg if (gpu_arch >= 800) { 41f80f4a74SSebastian Grimberg // sm80 or newer 42*940a72f1SSebastian Grimberg return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_a100 : sgemm_tn_a100) : ((trans_A == 'n') ? dgemm_nn_a100 : dgemm_tn_a100); 43f80f4a74SSebastian Grimberg } else { 44f80f4a74SSebastian Grimberg // sm70 or older 45*940a72f1SSebastian Grimberg return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_v100 : sgemm_tn_v100) : ((trans_A == 'n') ? dgemm_nn_v100 : dgemm_tn_v100); 46*940a72f1SSebastian Grimberg } 47f80f4a74SSebastian Grimberg } 48f80f4a74SSebastian Grimberg #endif 49f80f4a74SSebastian Grimberg 50f80f4a74SSebastian Grimberg //////////////////////////////////////////////////////////////////////////////// 51f80f4a74SSebastian Grimberg void gemm_selector(int gpu_arch, char precision, char trans_A, int m, int n, int k, int *n_batch, int *use_magma) { 52*940a72f1SSebastian Grimberg const auto &data = gemm_selector_get_data(gpu_arch, precision, trans_A); 53f80f4a74SSebastian Grimberg int ir = -1; 54f80f4a74SSebastian Grimberg double norm = std::numeric_limits<double>::max(); 55*940a72f1SSebastian Grimberg 56*940a72f1SSebastian Grimberg for (size_t i = 0; i < data.size(); i++) { 57*940a72f1SSebastian Grimberg const int &im = data[i][M_INDEX]; 58*940a72f1SSebastian Grimberg const int &in = data[i][N_INDEX]; 59*940a72f1SSebastian Grimberg const int &ik = data[i][K_INDEX]; 60f80f4a74SSebastian Grimberg 61f80f4a74SSebastian Grimberg double mdiff = (double)(im - m); 62f80f4a74SSebastian Grimberg double ndiff = (double)(in - n); 63f80f4a74SSebastian Grimberg double kdiff = (double)(ik - k); 64*940a72f1SSebastian Grimberg double nrm = mdiff * mdiff + ndiff * ndiff + kdiff * kdiff; 65f80f4a74SSebastian Grimberg 66f80f4a74SSebastian Grimberg if (nrm < norm) { 67f80f4a74SSebastian Grimberg norm = nrm; 68f80f4a74SSebastian Grimberg ir = i; 69f80f4a74SSebastian Grimberg } 70f80f4a74SSebastian Grimberg 71*940a72f1SSebastian Grimberg if (im == m && in == n && ik == k) { 72*940a72f1SSebastian Grimberg // The input (m, n, k) exactly matches a record in `data`, no need to search further 73f80f4a74SSebastian Grimberg break; 74f80f4a74SSebastian Grimberg } 75f80f4a74SSebastian Grimberg } 76f80f4a74SSebastian Grimberg 77f80f4a74SSebastian Grimberg if (ir >= 0) { 78*940a72f1SSebastian Grimberg // If the closest match indicates that n = n_batch, that means calling the regular non-batch GEMM. 79*940a72f1SSebastian Grimberg // So n_batch is set to n instead of the 'n_batch' entry of the matching record. 80*940a72f1SSebastian Grimberg int n_ = data[ir][N_INDEX]; 81*940a72f1SSebastian Grimberg int n_batch_ = data[ir][N_BATCH_INDEX]; 82f80f4a74SSebastian Grimberg *n_batch = (n_ == n_batch_) ? n : n_batch_; 83*940a72f1SSebastian Grimberg *use_magma = data[ir][USE_MAGMA_INDEX]; 84*940a72f1SSebastian Grimberg } else { 85*940a72f1SSebastian Grimberg *n_batch = n; 86*940a72f1SSebastian Grimberg *use_magma = 0; 87f80f4a74SSebastian Grimberg } 88f80f4a74SSebastian Grimberg } 89f80f4a74SSebastian Grimberg 90*940a72f1SSebastian Grimberg ////////////////////////////////////////////////////////////////////////////// 91f80f4a74SSebastian Grimberg #ifdef CEED_MAGMA_USE_HIP 92*940a72f1SSebastian Grimberg static inline auto nontensor_rtc_get_data(int gpu_arch, char trans_A, int q_comp) -> decltype(dinterp_n_mi250x) { 93*940a72f1SSebastian Grimberg if (q_comp == 1) { 94*940a72f1SSebastian Grimberg return (trans_A == 'n') ? dinterp_n_mi250x : dinterp_t_mi250x; 95*940a72f1SSebastian Grimberg } else { 96*940a72f1SSebastian Grimberg return (trans_A == 'n') ? dgrad_n_mi250x : dgrad_t_mi250x; 97*940a72f1SSebastian Grimberg } 98f80f4a74SSebastian Grimberg } 99f80f4a74SSebastian Grimberg #else 100*940a72f1SSebastian Grimberg static inline auto nontensor_rtc_get_data(int gpu_arch, char trans_A, int q_comp) -> decltype(dinterp_n_a100) { 101*940a72f1SSebastian Grimberg if (q_comp == 1) { 102*940a72f1SSebastian Grimberg return (trans_A == 'n') ? dinterp_n_a100 : dinterp_t_a100; 103*940a72f1SSebastian Grimberg } else { 104*940a72f1SSebastian Grimberg return (trans_A == 'n') ? dgrad_n_a100 : dgrad_t_a100; 105*940a72f1SSebastian Grimberg } 106f80f4a74SSebastian Grimberg } 107f80f4a74SSebastian Grimberg #endif 108f80f4a74SSebastian Grimberg 109f80f4a74SSebastian Grimberg //////////////////////////////////////////////////////////////////////////////// 110*940a72f1SSebastian Grimberg CeedInt nontensor_rtc_get_nb(int gpu_arch, char trans_A, int q_comp, int P, int Q, int n) { 111*940a72f1SSebastian Grimberg const auto &data = nontensor_rtc_get_data(gpu_arch, trans_A, q_comp); 112f80f4a74SSebastian Grimberg int ir = -1; 113f80f4a74SSebastian Grimberg double norm = std::numeric_limits<double>::max(); 114*940a72f1SSebastian Grimberg CeedInt m = (trans_A == 'n') ? Q : P; 115*940a72f1SSebastian Grimberg CeedInt k = (trans_A == 'n') ? P : Q; 116f80f4a74SSebastian Grimberg 117*940a72f1SSebastian Grimberg for (size_t i = 0; i < data.size(); i++) { 118*940a72f1SSebastian Grimberg const int &im = data[i][M_INDEX_RTC]; 119*940a72f1SSebastian Grimberg const int &in = data[i][N_INDEX_RTC]; 120*940a72f1SSebastian Grimberg const int &ik = data[i][K_INDEX_RTC]; 121*940a72f1SSebastian Grimberg 122*940a72f1SSebastian Grimberg double mdiff = (double)(im - m); 123*940a72f1SSebastian Grimberg double ndiff = (double)(in - n); 124*940a72f1SSebastian Grimberg double kdiff = (double)(ik - k); 125*940a72f1SSebastian Grimberg double nrm = mdiff * mdiff + ndiff * ndiff + kdiff * kdiff; 126f80f4a74SSebastian Grimberg 127f80f4a74SSebastian Grimberg if (nrm < norm) { 128f80f4a74SSebastian Grimberg norm = nrm; 129f80f4a74SSebastian Grimberg ir = i; 130f80f4a74SSebastian Grimberg } 131f80f4a74SSebastian Grimberg 132*940a72f1SSebastian Grimberg if (im == m && in == n && ik == k) { 133*940a72f1SSebastian Grimberg // The input (m, n, k) exactly matches a record in `data`, no need to search further 134f80f4a74SSebastian Grimberg break; 135f80f4a74SSebastian Grimberg } 136f80f4a74SSebastian Grimberg } 137f80f4a74SSebastian Grimberg 138*940a72f1SSebastian Grimberg return (ir >= 0) ? data[ir][NB_INDEX_RTC] : 1; 139f80f4a74SSebastian Grimberg } 140