1940a72f1SSebastian Grimberg // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2940a72f1SSebastian Grimberg // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3940a72f1SSebastian Grimberg // 4940a72f1SSebastian Grimberg // SPDX-License-Identifier: BSD-2-Clause 5940a72f1SSebastian Grimberg // 6940a72f1SSebastian Grimberg // This file is part of CEED: http://github.com/ceed 7940a72f1SSebastian Grimberg 8f80f4a74SSebastian Grimberg #include <array> 9f80f4a74SSebastian Grimberg #include <limits> 10f80f4a74SSebastian Grimberg #include <vector> 11f80f4a74SSebastian Grimberg 12913f8461SSebastian Grimberg #include "ceed-magma-gemm-selector.h" 13913f8461SSebastian Grimberg 14f80f4a74SSebastian Grimberg #include "tuning/indices.h" 15f80f4a74SSebastian Grimberg #ifdef CEED_MAGMA_USE_HIP 16f80f4a74SSebastian Grimberg #include "tuning/mi100.h" 177c7f2ed8SSebastian Grimberg #include "tuning/mi100_rtc.h" 18f80f4a74SSebastian Grimberg #include "tuning/mi250x.h" 197c7f2ed8SSebastian Grimberg #include "tuning/mi250x_rtc.h" 20f80f4a74SSebastian Grimberg #else 21f80f4a74SSebastian Grimberg #include "tuning/a100.h" 227c7f2ed8SSebastian Grimberg #include "tuning/a100_rtc.h" 23*dc215721SSebastian Grimberg #include "tuning/h100_rtc.h" 24f80f4a74SSebastian Grimberg #include "tuning/v100.h" 257c7f2ed8SSebastian Grimberg #include "tuning/v100_rtc.h" 26f80f4a74SSebastian Grimberg #endif 27f80f4a74SSebastian Grimberg 2826bdecf3SSebastian Grimberg // These definitions to force a certain parameter when generating autotuning data offline 297c7f2ed8SSebastian Grimberg // #define CEED_AUTOTUNE_GEMM_SELECTOR_N_BATCH 307c7f2ed8SSebastian Grimberg // #define CEED_AUTOTUNE_GEMM_SELECTOR_USE_MAGMA 317c7f2ed8SSebastian Grimberg // #define CEED_AUTOTUNE_RTC_NB 3226bdecf3SSebastian Grimberg 33f80f4a74SSebastian Grimberg //////////////////////////////////////////////////////////////////////////////// 34f80f4a74SSebastian Grimberg #ifdef CEED_MAGMA_USE_HIP 35*dc215721SSebastian Grimberg static inline auto gemm_selector_get_data(int gpu_arch, char precision, char trans_A) -> decltype(dgemm_nn_mi100) { 36f80f4a74SSebastian Grimberg if (gpu_arch >= 910) { 37f80f4a74SSebastian Grimberg // gfx90a or newer 38940a72f1SSebastian Grimberg return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_mi250x : sgemm_tn_mi250x) : ((trans_A == 'n') ? dgemm_nn_mi250x : dgemm_tn_mi250x); 39f80f4a74SSebastian Grimberg } else { 40f80f4a74SSebastian Grimberg // gfx908 or older 41940a72f1SSebastian Grimberg return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_mi100 : sgemm_tn_mi100) : ((trans_A == 'n') ? dgemm_nn_mi100 : dgemm_tn_mi100); 42940a72f1SSebastian Grimberg } 43f80f4a74SSebastian Grimberg } 44f80f4a74SSebastian Grimberg #else 45*dc215721SSebastian Grimberg static inline auto gemm_selector_get_data(int gpu_arch, char precision, char trans_A) -> decltype(dgemm_nn_v100) { 46f80f4a74SSebastian Grimberg if (gpu_arch >= 800) { 47f80f4a74SSebastian Grimberg // sm80 or newer 48940a72f1SSebastian Grimberg return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_a100 : sgemm_tn_a100) : ((trans_A == 'n') ? dgemm_nn_a100 : dgemm_tn_a100); 49f80f4a74SSebastian Grimberg } else { 50f80f4a74SSebastian Grimberg // sm70 or older 51940a72f1SSebastian Grimberg return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_v100 : sgemm_tn_v100) : ((trans_A == 'n') ? dgemm_nn_v100 : dgemm_tn_v100); 52940a72f1SSebastian Grimberg } 53f80f4a74SSebastian Grimberg } 54f80f4a74SSebastian Grimberg #endif 55f80f4a74SSebastian Grimberg 56f80f4a74SSebastian Grimberg //////////////////////////////////////////////////////////////////////////////// 57f80f4a74SSebastian Grimberg void gemm_selector(int gpu_arch, char precision, char trans_A, int m, int n, int k, int *n_batch, int *use_magma) { 5826bdecf3SSebastian Grimberg #if defined(CEED_AUTOTUNE_GEMM_SELECTOR_N_BATCH) && defined(CEED_AUTOTUNE_GEMM_SELECTOR_USE_MAGMA) 5926bdecf3SSebastian Grimberg *n_batch = CEED_AUTOTUNE_GEMM_SELECTOR_N_BATCH; 6026bdecf3SSebastian Grimberg *use_magma = CEED_AUTOTUNE_GEMM_SELECTOR_USE_MAGMA; 6126bdecf3SSebastian Grimberg #else 62940a72f1SSebastian Grimberg const auto &data = gemm_selector_get_data(gpu_arch, precision, trans_A); 63f80f4a74SSebastian Grimberg int ir = -1; 64f80f4a74SSebastian Grimberg double norm = std::numeric_limits<double>::max(); 65940a72f1SSebastian Grimberg 66940a72f1SSebastian Grimberg for (size_t i = 0; i < data.size(); i++) { 67940a72f1SSebastian Grimberg const int &im = data[i][M_INDEX]; 68940a72f1SSebastian Grimberg const int &in = data[i][N_INDEX]; 69940a72f1SSebastian Grimberg const int &ik = data[i][K_INDEX]; 70f80f4a74SSebastian Grimberg 71f80f4a74SSebastian Grimberg double mdiff = (double)(im - m); 72f80f4a74SSebastian Grimberg double ndiff = (double)(in - n); 73f80f4a74SSebastian Grimberg double kdiff = (double)(ik - k); 74940a72f1SSebastian Grimberg double nrm = mdiff * mdiff + ndiff * ndiff + kdiff * kdiff; 75f80f4a74SSebastian Grimberg 76f80f4a74SSebastian Grimberg if (nrm < norm) { 77f80f4a74SSebastian Grimberg norm = nrm; 78f80f4a74SSebastian Grimberg ir = i; 79f80f4a74SSebastian Grimberg } 80f80f4a74SSebastian Grimberg 81940a72f1SSebastian Grimberg if (im == m && in == n && ik == k) { 82940a72f1SSebastian Grimberg // The input (m, n, k) exactly matches a record in `data`, no need to search further 83f80f4a74SSebastian Grimberg break; 84f80f4a74SSebastian Grimberg } 85f80f4a74SSebastian Grimberg } 86f80f4a74SSebastian Grimberg 87f80f4a74SSebastian Grimberg if (ir >= 0) { 88940a72f1SSebastian Grimberg // If the closest match indicates that n = n_batch, that means calling the regular non-batch GEMM. 89940a72f1SSebastian Grimberg // So n_batch is set to n instead of the 'n_batch' entry of the matching record. 90940a72f1SSebastian Grimberg int n_ = data[ir][N_INDEX]; 91940a72f1SSebastian Grimberg int n_batch_ = data[ir][N_BATCH_INDEX]; 92f80f4a74SSebastian Grimberg *n_batch = (n_ == n_batch_) ? n : n_batch_; 93940a72f1SSebastian Grimberg *use_magma = data[ir][USE_MAGMA_INDEX]; 94940a72f1SSebastian Grimberg } else { 95940a72f1SSebastian Grimberg *n_batch = n; 96940a72f1SSebastian Grimberg *use_magma = 0; 97f80f4a74SSebastian Grimberg } 9826bdecf3SSebastian Grimberg #endif 99f80f4a74SSebastian Grimberg } 100f80f4a74SSebastian Grimberg 101940a72f1SSebastian Grimberg ////////////////////////////////////////////////////////////////////////////// 102f80f4a74SSebastian Grimberg #ifdef CEED_MAGMA_USE_HIP 103*dc215721SSebastian Grimberg static inline auto nontensor_rtc_get_data(int gpu_arch, char trans_A) -> decltype(drtc_n_mi100) { 1047c7f2ed8SSebastian Grimberg if (gpu_arch >= 910) { 1057c7f2ed8SSebastian Grimberg // gfx90a or newer 1067c7f2ed8SSebastian Grimberg return (trans_A == 'n') ? drtc_n_mi250x : drtc_t_mi250x; 107940a72f1SSebastian Grimberg } else { 1087c7f2ed8SSebastian Grimberg // gfx908 or older 1097c7f2ed8SSebastian Grimberg return (trans_A == 'n') ? drtc_n_mi100 : drtc_t_mi100; 110940a72f1SSebastian Grimberg } 111f80f4a74SSebastian Grimberg } 112f80f4a74SSebastian Grimberg #else 113*dc215721SSebastian Grimberg static inline auto nontensor_rtc_get_data(int gpu_arch, char trans_A) -> decltype(drtc_n_v100) { 114*dc215721SSebastian Grimberg if (gpu_arch >= 900) { 115*dc215721SSebastian Grimberg // sm90 or newer 116*dc215721SSebastian Grimberg return (trans_A == 'n') ? drtc_n_h100 : drtc_t_h100; 117*dc215721SSebastian Grimberg } else if (gpu_arch >= 800) { 1187c7f2ed8SSebastian Grimberg // sm80 or newer 1197c7f2ed8SSebastian Grimberg return (trans_A == 'n') ? drtc_n_a100 : drtc_t_a100; 120940a72f1SSebastian Grimberg } else { 1217c7f2ed8SSebastian Grimberg // sm70 or older 1227c7f2ed8SSebastian Grimberg return (trans_A == 'n') ? drtc_n_v100 : drtc_t_v100; 123940a72f1SSebastian Grimberg } 124f80f4a74SSebastian Grimberg } 125f80f4a74SSebastian Grimberg #endif 126f80f4a74SSebastian Grimberg 127f80f4a74SSebastian Grimberg //////////////////////////////////////////////////////////////////////////////// 1287c7f2ed8SSebastian Grimberg CeedInt nontensor_rtc_get_nb(int gpu_arch, char trans_A, int q_comp, int P, int Q, int N) { 12926bdecf3SSebastian Grimberg #ifdef CEED_AUTOTUNE_RTC_NB 13026bdecf3SSebastian Grimberg return CEED_AUTOTUNE_RTC_NB; 13126bdecf3SSebastian Grimberg #else 1327c7f2ed8SSebastian Grimberg const auto &data = nontensor_rtc_get_data(gpu_arch, trans_A); 133f80f4a74SSebastian Grimberg int ir = -1; 134f80f4a74SSebastian Grimberg double norm = std::numeric_limits<double>::max(); 135f80f4a74SSebastian Grimberg 136940a72f1SSebastian Grimberg for (size_t i = 0; i < data.size(); i++) { 1377c7f2ed8SSebastian Grimberg // Only seach exact matches for q_comp 1387c7f2ed8SSebastian Grimberg if (q_comp != data[i][Q_COMP_INDEX_RTC]) { 1397c7f2ed8SSebastian Grimberg continue; 1407c7f2ed8SSebastian Grimberg } 141940a72f1SSebastian Grimberg 1427c7f2ed8SSebastian Grimberg const int &iP = data[i][P_INDEX_RTC]; 1437c7f2ed8SSebastian Grimberg const int &iQ = data[i][Q_INDEX_RTC]; 1447c7f2ed8SSebastian Grimberg const int &iN = data[i][N_INDEX_RTC]; 1457c7f2ed8SSebastian Grimberg 1467c7f2ed8SSebastian Grimberg double Pdiff = (double)(iP - P); 1477c7f2ed8SSebastian Grimberg double Qdiff = (double)(iQ - Q); 1487c7f2ed8SSebastian Grimberg double Ndiff = (double)(iN - N); 1497c7f2ed8SSebastian Grimberg double nrm = Pdiff * Pdiff + Qdiff * Qdiff + Ndiff * Ndiff; 150f80f4a74SSebastian Grimberg 151f80f4a74SSebastian Grimberg if (nrm < norm) { 152f80f4a74SSebastian Grimberg norm = nrm; 153f80f4a74SSebastian Grimberg ir = i; 154f80f4a74SSebastian Grimberg } 155f80f4a74SSebastian Grimberg 1567c7f2ed8SSebastian Grimberg if (iP == P && iQ == Q && iN == N) { 1577c7f2ed8SSebastian Grimberg // The input (P, Q, N) exactly matches a record in `data`, no need to search further 158f80f4a74SSebastian Grimberg break; 159f80f4a74SSebastian Grimberg } 160f80f4a74SSebastian Grimberg } 161f80f4a74SSebastian Grimberg 162940a72f1SSebastian Grimberg return (ir >= 0) ? data[ir][NB_INDEX_RTC] : 1; 16326bdecf3SSebastian Grimberg #endif 164f80f4a74SSebastian Grimberg } 165