xref: /libCEED/rust/libceed-sys/c-src/backends/magma/ceed-magma-gemm-selector.cpp (revision 940a72f1a85a7fc8459dbc83c7f6f7637fe1955b)
// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED:  http://github.com/ceed
7*940a72f1SSebastian Grimberg 
8*940a72f1SSebastian Grimberg #include "ceed-magma-gemm-selector.h"
9f80f4a74SSebastian Grimberg 
10f80f4a74SSebastian Grimberg #include <array>
11f80f4a74SSebastian Grimberg #include <limits>
12f80f4a74SSebastian Grimberg #include <vector>
13f80f4a74SSebastian Grimberg 
14f80f4a74SSebastian Grimberg #include "tuning/indices.h"
15f80f4a74SSebastian Grimberg #ifdef CEED_MAGMA_USE_HIP
16f80f4a74SSebastian Grimberg #include "tuning/mi100.h"
17f80f4a74SSebastian Grimberg #include "tuning/mi250x.h"
18f80f4a74SSebastian Grimberg #include "tuning/mi250x_grad_rtc.h"
19f80f4a74SSebastian Grimberg #include "tuning/mi250x_interp_rtc.h"
20f80f4a74SSebastian Grimberg #else
21f80f4a74SSebastian Grimberg #include "tuning/a100.h"
22f80f4a74SSebastian Grimberg #include "tuning/a100_grad_rtc.h"
23f80f4a74SSebastian Grimberg #include "tuning/a100_interp_rtc.h"
24f80f4a74SSebastian Grimberg #include "tuning/v100.h"
25f80f4a74SSebastian Grimberg #endif
26f80f4a74SSebastian Grimberg 
27f80f4a74SSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
28f80f4a74SSebastian Grimberg #ifdef CEED_MAGMA_USE_HIP
29*940a72f1SSebastian Grimberg static inline auto gemm_selector_get_data(int gpu_arch, char precision, char trans_A) -> decltype(dgemm_nn_mi250x) {
30f80f4a74SSebastian Grimberg   if (gpu_arch >= 910) {
31f80f4a74SSebastian Grimberg     // gfx90a or newer
32*940a72f1SSebastian Grimberg     return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_mi250x : sgemm_tn_mi250x) : ((trans_A == 'n') ? dgemm_nn_mi250x : dgemm_tn_mi250x);
33f80f4a74SSebastian Grimberg   } else {
34f80f4a74SSebastian Grimberg     // gfx908 or older
35*940a72f1SSebastian Grimberg     return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_mi100 : sgemm_tn_mi100) : ((trans_A == 'n') ? dgemm_nn_mi100 : dgemm_tn_mi100);
36*940a72f1SSebastian Grimberg   }
37f80f4a74SSebastian Grimberg }
38f80f4a74SSebastian Grimberg #else
39*940a72f1SSebastian Grimberg static inline auto gemm_selector_get_data(int gpu_arch, char precision, char trans_A) -> decltype(dgemm_nn_a100) {
40f80f4a74SSebastian Grimberg   if (gpu_arch >= 800) {
41f80f4a74SSebastian Grimberg     // sm80 or newer
42*940a72f1SSebastian Grimberg     return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_a100 : sgemm_tn_a100) : ((trans_A == 'n') ? dgemm_nn_a100 : dgemm_tn_a100);
43f80f4a74SSebastian Grimberg   } else {
44f80f4a74SSebastian Grimberg     // sm70 or older
45*940a72f1SSebastian Grimberg     return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_v100 : sgemm_tn_v100) : ((trans_A == 'n') ? dgemm_nn_v100 : dgemm_tn_v100);
46*940a72f1SSebastian Grimberg   }
47f80f4a74SSebastian Grimberg }
48f80f4a74SSebastian Grimberg #endif
49f80f4a74SSebastian Grimberg 
50f80f4a74SSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
51f80f4a74SSebastian Grimberg void gemm_selector(int gpu_arch, char precision, char trans_A, int m, int n, int k, int *n_batch, int *use_magma) {
52*940a72f1SSebastian Grimberg   const auto &data = gemm_selector_get_data(gpu_arch, precision, trans_A);
53f80f4a74SSebastian Grimberg   int         ir   = -1;
54f80f4a74SSebastian Grimberg   double      norm = std::numeric_limits<double>::max();
55*940a72f1SSebastian Grimberg 
56*940a72f1SSebastian Grimberg   for (size_t i = 0; i < data.size(); i++) {
57*940a72f1SSebastian Grimberg     const int &im = data[i][M_INDEX];
58*940a72f1SSebastian Grimberg     const int &in = data[i][N_INDEX];
59*940a72f1SSebastian Grimberg     const int &ik = data[i][K_INDEX];
60f80f4a74SSebastian Grimberg 
61f80f4a74SSebastian Grimberg     double mdiff = (double)(im - m);
62f80f4a74SSebastian Grimberg     double ndiff = (double)(in - n);
63f80f4a74SSebastian Grimberg     double kdiff = (double)(ik - k);
64*940a72f1SSebastian Grimberg     double nrm   = mdiff * mdiff + ndiff * ndiff + kdiff * kdiff;
65f80f4a74SSebastian Grimberg 
66f80f4a74SSebastian Grimberg     if (nrm < norm) {
67f80f4a74SSebastian Grimberg       norm = nrm;
68f80f4a74SSebastian Grimberg       ir   = i;
69f80f4a74SSebastian Grimberg     }
70f80f4a74SSebastian Grimberg 
71*940a72f1SSebastian Grimberg     if (im == m && in == n && ik == k) {
72*940a72f1SSebastian Grimberg       // The input (m, n, k) exactly matches a record in `data`, no need to search further
73f80f4a74SSebastian Grimberg       break;
74f80f4a74SSebastian Grimberg     }
75f80f4a74SSebastian Grimberg   }
76f80f4a74SSebastian Grimberg 
77f80f4a74SSebastian Grimberg   if (ir >= 0) {
78*940a72f1SSebastian Grimberg     // If the closest match indicates that n = n_batch, that means calling the regular non-batch GEMM.
79*940a72f1SSebastian Grimberg     // So n_batch is set to n instead of the 'n_batch' entry of the matching record.
80*940a72f1SSebastian Grimberg     int n_       = data[ir][N_INDEX];
81*940a72f1SSebastian Grimberg     int n_batch_ = data[ir][N_BATCH_INDEX];
82f80f4a74SSebastian Grimberg     *n_batch     = (n_ == n_batch_) ? n : n_batch_;
83*940a72f1SSebastian Grimberg     *use_magma   = data[ir][USE_MAGMA_INDEX];
84*940a72f1SSebastian Grimberg   } else {
85*940a72f1SSebastian Grimberg     *n_batch   = n;
86*940a72f1SSebastian Grimberg     *use_magma = 0;
87f80f4a74SSebastian Grimberg   }
88f80f4a74SSebastian Grimberg }
89f80f4a74SSebastian Grimberg 
90*940a72f1SSebastian Grimberg //////////////////////////////////////////////////////////////////////////////
91f80f4a74SSebastian Grimberg #ifdef CEED_MAGMA_USE_HIP
92*940a72f1SSebastian Grimberg static inline auto nontensor_rtc_get_data(int gpu_arch, char trans_A, int q_comp) -> decltype(dinterp_n_mi250x) {
93*940a72f1SSebastian Grimberg   if (q_comp == 1) {
94*940a72f1SSebastian Grimberg     return (trans_A == 'n') ? dinterp_n_mi250x : dinterp_t_mi250x;
95*940a72f1SSebastian Grimberg   } else {
96*940a72f1SSebastian Grimberg     return (trans_A == 'n') ? dgrad_n_mi250x : dgrad_t_mi250x;
97*940a72f1SSebastian Grimberg   }
98f80f4a74SSebastian Grimberg }
99f80f4a74SSebastian Grimberg #else
100*940a72f1SSebastian Grimberg static inline auto nontensor_rtc_get_data(int gpu_arch, char trans_A, int q_comp) -> decltype(dinterp_n_a100) {
101*940a72f1SSebastian Grimberg   if (q_comp == 1) {
102*940a72f1SSebastian Grimberg     return (trans_A == 'n') ? dinterp_n_a100 : dinterp_t_a100;
103*940a72f1SSebastian Grimberg   } else {
104*940a72f1SSebastian Grimberg     return (trans_A == 'n') ? dgrad_n_a100 : dgrad_t_a100;
105*940a72f1SSebastian Grimberg   }
106f80f4a74SSebastian Grimberg }
107f80f4a74SSebastian Grimberg #endif
108f80f4a74SSebastian Grimberg 
109f80f4a74SSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
110*940a72f1SSebastian Grimberg CeedInt nontensor_rtc_get_nb(int gpu_arch, char trans_A, int q_comp, int P, int Q, int n) {
111*940a72f1SSebastian Grimberg   const auto &data = nontensor_rtc_get_data(gpu_arch, trans_A, q_comp);
112f80f4a74SSebastian Grimberg   int         ir   = -1;
113f80f4a74SSebastian Grimberg   double      norm = std::numeric_limits<double>::max();
114*940a72f1SSebastian Grimberg   CeedInt     m    = (trans_A == 'n') ? Q : P;
115*940a72f1SSebastian Grimberg   CeedInt     k    = (trans_A == 'n') ? P : Q;
116f80f4a74SSebastian Grimberg 
117*940a72f1SSebastian Grimberg   for (size_t i = 0; i < data.size(); i++) {
118*940a72f1SSebastian Grimberg     const int &im = data[i][M_INDEX_RTC];
119*940a72f1SSebastian Grimberg     const int &in = data[i][N_INDEX_RTC];
120*940a72f1SSebastian Grimberg     const int &ik = data[i][K_INDEX_RTC];
121*940a72f1SSebastian Grimberg 
122*940a72f1SSebastian Grimberg     double mdiff = (double)(im - m);
123*940a72f1SSebastian Grimberg     double ndiff = (double)(in - n);
124*940a72f1SSebastian Grimberg     double kdiff = (double)(ik - k);
125*940a72f1SSebastian Grimberg     double nrm   = mdiff * mdiff + ndiff * ndiff + kdiff * kdiff;
126f80f4a74SSebastian Grimberg 
127f80f4a74SSebastian Grimberg     if (nrm < norm) {
128f80f4a74SSebastian Grimberg       norm = nrm;
129f80f4a74SSebastian Grimberg       ir   = i;
130f80f4a74SSebastian Grimberg     }
131f80f4a74SSebastian Grimberg 
132*940a72f1SSebastian Grimberg     if (im == m && in == n && ik == k) {
133*940a72f1SSebastian Grimberg       // The input (m, n, k) exactly matches a record in `data`, no need to search further
134f80f4a74SSebastian Grimberg       break;
135f80f4a74SSebastian Grimberg     }
136f80f4a74SSebastian Grimberg   }
137f80f4a74SSebastian Grimberg 
138*940a72f1SSebastian Grimberg   return (ir >= 0) ? data[ir][NB_INDEX_RTC] : 1;
139f80f4a74SSebastian Grimberg }
140