xref: /libCEED/rust/libceed-sys/c-src/backends/magma/ceed-magma-gemm-selector.cpp (revision d4cc18453651bd0f94c1a2e078b2646a92dafdcc)
1*9ba83ac0SJeremy L Thompson // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2940a72f1SSebastian Grimberg // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3940a72f1SSebastian Grimberg //
4940a72f1SSebastian Grimberg // SPDX-License-Identifier: BSD-2-Clause
5940a72f1SSebastian Grimberg //
6940a72f1SSebastian Grimberg // This file is part of CEED:  http://github.com/ceed
7940a72f1SSebastian Grimberg 
8f80f4a74SSebastian Grimberg #include <array>
9f80f4a74SSebastian Grimberg #include <limits>
10f80f4a74SSebastian Grimberg #include <vector>
11f80f4a74SSebastian Grimberg 
12913f8461SSebastian Grimberg #include "ceed-magma-gemm-selector.h"
13913f8461SSebastian Grimberg 
14f80f4a74SSebastian Grimberg #include "tuning/indices.h"
15f80f4a74SSebastian Grimberg #ifdef CEED_MAGMA_USE_HIP
16f80f4a74SSebastian Grimberg #include "tuning/mi100.h"
177c7f2ed8SSebastian Grimberg #include "tuning/mi100_rtc.h"
18f80f4a74SSebastian Grimberg #include "tuning/mi250x.h"
197c7f2ed8SSebastian Grimberg #include "tuning/mi250x_rtc.h"
20f80f4a74SSebastian Grimberg #else
21f80f4a74SSebastian Grimberg #include "tuning/a100.h"
227c7f2ed8SSebastian Grimberg #include "tuning/a100_rtc.h"
23dc215721SSebastian Grimberg #include "tuning/h100_rtc.h"
24f80f4a74SSebastian Grimberg #include "tuning/v100.h"
257c7f2ed8SSebastian Grimberg #include "tuning/v100_rtc.h"
26f80f4a74SSebastian Grimberg #endif
27f80f4a74SSebastian Grimberg 
// Define these macros to force a specific parameter value when generating autotuning data offline
297c7f2ed8SSebastian Grimberg // #define CEED_AUTOTUNE_GEMM_SELECTOR_N_BATCH
307c7f2ed8SSebastian Grimberg // #define CEED_AUTOTUNE_GEMM_SELECTOR_USE_MAGMA
317c7f2ed8SSebastian Grimberg // #define CEED_AUTOTUNE_RTC_NB
3226bdecf3SSebastian Grimberg 
33f80f4a74SSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
34f80f4a74SSebastian Grimberg #ifdef CEED_MAGMA_USE_HIP
gemm_selector_get_data(int gpu_arch,char precision,char trans_A)35dc215721SSebastian Grimberg static inline auto gemm_selector_get_data(int gpu_arch, char precision, char trans_A) -> decltype(dgemm_nn_mi100) {
36f80f4a74SSebastian Grimberg   if (gpu_arch >= 910) {
37f80f4a74SSebastian Grimberg     // gfx90a or newer
38940a72f1SSebastian Grimberg     return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_mi250x : sgemm_tn_mi250x) : ((trans_A == 'n') ? dgemm_nn_mi250x : dgemm_tn_mi250x);
39f80f4a74SSebastian Grimberg   } else {
40f80f4a74SSebastian Grimberg     // gfx908 or older
41940a72f1SSebastian Grimberg     return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_mi100 : sgemm_tn_mi100) : ((trans_A == 'n') ? dgemm_nn_mi100 : dgemm_tn_mi100);
42940a72f1SSebastian Grimberg   }
43f80f4a74SSebastian Grimberg }
44f80f4a74SSebastian Grimberg #else
gemm_selector_get_data(int gpu_arch,char precision,char trans_A)45dc215721SSebastian Grimberg static inline auto gemm_selector_get_data(int gpu_arch, char precision, char trans_A) -> decltype(dgemm_nn_v100) {
46f80f4a74SSebastian Grimberg   if (gpu_arch >= 800) {
47f80f4a74SSebastian Grimberg     // sm80 or newer
48940a72f1SSebastian Grimberg     return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_a100 : sgemm_tn_a100) : ((trans_A == 'n') ? dgemm_nn_a100 : dgemm_tn_a100);
49f80f4a74SSebastian Grimberg   } else {
50f80f4a74SSebastian Grimberg     // sm70 or older
51940a72f1SSebastian Grimberg     return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_v100 : sgemm_tn_v100) : ((trans_A == 'n') ? dgemm_nn_v100 : dgemm_tn_v100);
52940a72f1SSebastian Grimberg   }
53f80f4a74SSebastian Grimberg }
54f80f4a74SSebastian Grimberg #endif
55f80f4a74SSebastian Grimberg 
56f80f4a74SSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
gemm_selector(int gpu_arch,char precision,char trans_A,int m,int n,int k,int * n_batch,int * use_magma)57f80f4a74SSebastian Grimberg void gemm_selector(int gpu_arch, char precision, char trans_A, int m, int n, int k, int *n_batch, int *use_magma) {
5826bdecf3SSebastian Grimberg #if defined(CEED_AUTOTUNE_GEMM_SELECTOR_N_BATCH) && defined(CEED_AUTOTUNE_GEMM_SELECTOR_USE_MAGMA)
5926bdecf3SSebastian Grimberg   *n_batch   = CEED_AUTOTUNE_GEMM_SELECTOR_N_BATCH;
6026bdecf3SSebastian Grimberg   *use_magma = CEED_AUTOTUNE_GEMM_SELECTOR_USE_MAGMA;
6126bdecf3SSebastian Grimberg #else
62940a72f1SSebastian Grimberg   const auto &data = gemm_selector_get_data(gpu_arch, precision, trans_A);
63f80f4a74SSebastian Grimberg   int         ir   = -1;
64f80f4a74SSebastian Grimberg   double      norm = std::numeric_limits<double>::max();
65940a72f1SSebastian Grimberg 
66940a72f1SSebastian Grimberg   for (size_t i = 0; i < data.size(); i++) {
67940a72f1SSebastian Grimberg     const int &im = data[i][M_INDEX];
68940a72f1SSebastian Grimberg     const int &in = data[i][N_INDEX];
69940a72f1SSebastian Grimberg     const int &ik = data[i][K_INDEX];
70f80f4a74SSebastian Grimberg 
71f80f4a74SSebastian Grimberg     double mdiff = (double)(im - m);
72f80f4a74SSebastian Grimberg     double ndiff = (double)(in - n);
73f80f4a74SSebastian Grimberg     double kdiff = (double)(ik - k);
74940a72f1SSebastian Grimberg     double nrm   = mdiff * mdiff + ndiff * ndiff + kdiff * kdiff;
75f80f4a74SSebastian Grimberg 
76f80f4a74SSebastian Grimberg     if (nrm < norm) {
77f80f4a74SSebastian Grimberg       norm = nrm;
78f80f4a74SSebastian Grimberg       ir   = i;
79f80f4a74SSebastian Grimberg     }
80f80f4a74SSebastian Grimberg 
81940a72f1SSebastian Grimberg     if (im == m && in == n && ik == k) {
82940a72f1SSebastian Grimberg       // The input (m, n, k) exactly matches a record in `data`, no need to search further
83f80f4a74SSebastian Grimberg       break;
84f80f4a74SSebastian Grimberg     }
85f80f4a74SSebastian Grimberg   }
86f80f4a74SSebastian Grimberg 
87f80f4a74SSebastian Grimberg   if (ir >= 0) {
88940a72f1SSebastian Grimberg     // If the closest match indicates that n = n_batch, that means calling the regular non-batch GEMM.
89940a72f1SSebastian Grimberg     // So n_batch is set to n instead of the 'n_batch' entry of the matching record.
90940a72f1SSebastian Grimberg     int n_       = data[ir][N_INDEX];
91940a72f1SSebastian Grimberg     int n_batch_ = data[ir][N_BATCH_INDEX];
92f80f4a74SSebastian Grimberg     *n_batch     = (n_ == n_batch_) ? n : n_batch_;
93940a72f1SSebastian Grimberg     *use_magma   = data[ir][USE_MAGMA_INDEX];
94940a72f1SSebastian Grimberg   } else {
95940a72f1SSebastian Grimberg     *n_batch   = n;
96940a72f1SSebastian Grimberg     *use_magma = 0;
97f80f4a74SSebastian Grimberg   }
9826bdecf3SSebastian Grimberg #endif
99f80f4a74SSebastian Grimberg }
100f80f4a74SSebastian Grimberg 
101940a72f1SSebastian Grimberg //////////////////////////////////////////////////////////////////////////////
102f80f4a74SSebastian Grimberg #ifdef CEED_MAGMA_USE_HIP
nontensor_rtc_get_data(int gpu_arch,char trans_A)103dc215721SSebastian Grimberg static inline auto nontensor_rtc_get_data(int gpu_arch, char trans_A) -> decltype(drtc_n_mi100) {
1047c7f2ed8SSebastian Grimberg   if (gpu_arch >= 910) {
1057c7f2ed8SSebastian Grimberg     // gfx90a or newer
1067c7f2ed8SSebastian Grimberg     return (trans_A == 'n') ? drtc_n_mi250x : drtc_t_mi250x;
107940a72f1SSebastian Grimberg   } else {
1087c7f2ed8SSebastian Grimberg     // gfx908 or older
1097c7f2ed8SSebastian Grimberg     return (trans_A == 'n') ? drtc_n_mi100 : drtc_t_mi100;
110940a72f1SSebastian Grimberg   }
111f80f4a74SSebastian Grimberg }
112f80f4a74SSebastian Grimberg #else
nontensor_rtc_get_data(int gpu_arch,char trans_A)113dc215721SSebastian Grimberg static inline auto nontensor_rtc_get_data(int gpu_arch, char trans_A) -> decltype(drtc_n_v100) {
114dc215721SSebastian Grimberg   if (gpu_arch >= 900) {
115dc215721SSebastian Grimberg     // sm90 or newer
116dc215721SSebastian Grimberg     return (trans_A == 'n') ? drtc_n_h100 : drtc_t_h100;
117dc215721SSebastian Grimberg   } else if (gpu_arch >= 800) {
1187c7f2ed8SSebastian Grimberg     // sm80 or newer
1197c7f2ed8SSebastian Grimberg     return (trans_A == 'n') ? drtc_n_a100 : drtc_t_a100;
120940a72f1SSebastian Grimberg   } else {
1217c7f2ed8SSebastian Grimberg     // sm70 or older
1227c7f2ed8SSebastian Grimberg     return (trans_A == 'n') ? drtc_n_v100 : drtc_t_v100;
123940a72f1SSebastian Grimberg   }
124f80f4a74SSebastian Grimberg }
125f80f4a74SSebastian Grimberg #endif
126f80f4a74SSebastian Grimberg 
127f80f4a74SSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
nontensor_rtc_get_nb(int gpu_arch,char trans_A,int q_comp,int P,int Q,int N)1287c7f2ed8SSebastian Grimberg CeedInt nontensor_rtc_get_nb(int gpu_arch, char trans_A, int q_comp, int P, int Q, int N) {
12926bdecf3SSebastian Grimberg #ifdef CEED_AUTOTUNE_RTC_NB
13026bdecf3SSebastian Grimberg   return CEED_AUTOTUNE_RTC_NB;
13126bdecf3SSebastian Grimberg #else
1327c7f2ed8SSebastian Grimberg   const auto &data = nontensor_rtc_get_data(gpu_arch, trans_A);
133f80f4a74SSebastian Grimberg   int         ir   = -1;
134f80f4a74SSebastian Grimberg   double      norm = std::numeric_limits<double>::max();
135f80f4a74SSebastian Grimberg 
136940a72f1SSebastian Grimberg   for (size_t i = 0; i < data.size(); i++) {
1377c7f2ed8SSebastian Grimberg     // Only seach exact matches for q_comp
1387c7f2ed8SSebastian Grimberg     if (q_comp != data[i][Q_COMP_INDEX_RTC]) {
1397c7f2ed8SSebastian Grimberg       continue;
1407c7f2ed8SSebastian Grimberg     }
141940a72f1SSebastian Grimberg 
1427c7f2ed8SSebastian Grimberg     const int &iP = data[i][P_INDEX_RTC];
1437c7f2ed8SSebastian Grimberg     const int &iQ = data[i][Q_INDEX_RTC];
1447c7f2ed8SSebastian Grimberg     const int &iN = data[i][N_INDEX_RTC];
1457c7f2ed8SSebastian Grimberg 
1467c7f2ed8SSebastian Grimberg     double Pdiff = (double)(iP - P);
1477c7f2ed8SSebastian Grimberg     double Qdiff = (double)(iQ - Q);
1487c7f2ed8SSebastian Grimberg     double Ndiff = (double)(iN - N);
1497c7f2ed8SSebastian Grimberg     double nrm   = Pdiff * Pdiff + Qdiff * Qdiff + Ndiff * Ndiff;
150f80f4a74SSebastian Grimberg 
151f80f4a74SSebastian Grimberg     if (nrm < norm) {
152f80f4a74SSebastian Grimberg       norm = nrm;
153f80f4a74SSebastian Grimberg       ir   = i;
154f80f4a74SSebastian Grimberg     }
155f80f4a74SSebastian Grimberg 
1567c7f2ed8SSebastian Grimberg     if (iP == P && iQ == Q && iN == N) {
1577c7f2ed8SSebastian Grimberg       // The input (P, Q, N) exactly matches a record in `data`, no need to search further
158f80f4a74SSebastian Grimberg       break;
159f80f4a74SSebastian Grimberg     }
160f80f4a74SSebastian Grimberg   }
161f80f4a74SSebastian Grimberg 
162940a72f1SSebastian Grimberg   return (ir >= 0) ? data[ir][NB_INDEX_RTC] : 1;
16326bdecf3SSebastian Grimberg #endif
164f80f4a74SSebastian Grimberg }
165