xref: /libCEED/backends/magma/ceed-magma-gemm-selector.cpp (revision 715f9ba89a309f24226005ca1fbb9f59fe9eac68)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <array>
9 #include <limits>
10 #include <vector>
11 
12 #include "ceed-magma-gemm-selector.h"
13 
14 #include "tuning/indices.h"
15 #ifdef CEED_MAGMA_USE_HIP
16 #include "tuning/mi100.h"
17 #include "tuning/mi250x.h"
18 #include "tuning/mi250x_grad_rtc.h"
19 #include "tuning/mi250x_interp_rtc.h"
20 #else
21 #include "tuning/a100.h"
22 #include "tuning/a100_grad_rtc.h"
23 #include "tuning/a100_interp_rtc.h"
24 #include "tuning/v100.h"
25 #endif
26 
27 ////////////////////////////////////////////////////////////////////////////////
28 #ifdef CEED_MAGMA_USE_HIP
29 static inline auto gemm_selector_get_data(int gpu_arch, char precision, char trans_A) -> decltype(dgemm_nn_mi250x) {
30   if (gpu_arch >= 910) {
31     // gfx90a or newer
32     return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_mi250x : sgemm_tn_mi250x) : ((trans_A == 'n') ? dgemm_nn_mi250x : dgemm_tn_mi250x);
33   } else {
34     // gfx908 or older
35     return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_mi100 : sgemm_tn_mi100) : ((trans_A == 'n') ? dgemm_nn_mi100 : dgemm_tn_mi100);
36   }
37 }
38 #else
39 static inline auto gemm_selector_get_data(int gpu_arch, char precision, char trans_A) -> decltype(dgemm_nn_a100) {
40   if (gpu_arch >= 800) {
41     // sm80 or newer
42     return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_a100 : sgemm_tn_a100) : ((trans_A == 'n') ? dgemm_nn_a100 : dgemm_tn_a100);
43   } else {
44     // sm70 or older
45     return (precision == 's') ? ((trans_A == 'n') ? sgemm_nn_v100 : sgemm_tn_v100) : ((trans_A == 'n') ? dgemm_nn_v100 : dgemm_tn_v100);
46   }
47 }
48 #endif
49 
50 ////////////////////////////////////////////////////////////////////////////////
51 void gemm_selector(int gpu_arch, char precision, char trans_A, int m, int n, int k, int *n_batch, int *use_magma) {
52   const auto &data = gemm_selector_get_data(gpu_arch, precision, trans_A);
53   int         ir   = -1;
54   double      norm = std::numeric_limits<double>::max();
55 
56   for (size_t i = 0; i < data.size(); i++) {
57     const int &im = data[i][M_INDEX];
58     const int &in = data[i][N_INDEX];
59     const int &ik = data[i][K_INDEX];
60 
61     double mdiff = (double)(im - m);
62     double ndiff = (double)(in - n);
63     double kdiff = (double)(ik - k);
64     double nrm   = mdiff * mdiff + ndiff * ndiff + kdiff * kdiff;
65 
66     if (nrm < norm) {
67       norm = nrm;
68       ir   = i;
69     }
70 
71     if (im == m && in == n && ik == k) {
72       // The input (m, n, k) exactly matches a record in `data`, no need to search further
73       break;
74     }
75   }
76 
77   if (ir >= 0) {
78     // If the closest match indicates that n = n_batch, that means calling the regular non-batch GEMM.
79     // So n_batch is set to n instead of the 'n_batch' entry of the matching record.
80     int n_       = data[ir][N_INDEX];
81     int n_batch_ = data[ir][N_BATCH_INDEX];
82     *n_batch     = (n_ == n_batch_) ? n : n_batch_;
83     *use_magma   = data[ir][USE_MAGMA_INDEX];
84   } else {
85     *n_batch   = n;
86     *use_magma = 0;
87   }
88 }
89 
90 //////////////////////////////////////////////////////////////////////////////
91 #ifdef CEED_MAGMA_USE_HIP
92 static inline auto nontensor_rtc_get_data(int gpu_arch, char trans_A, int q_comp) -> decltype(dinterp_n_mi250x) {
93   if (q_comp == 1) {
94     return (trans_A == 'n') ? dinterp_n_mi250x : dinterp_t_mi250x;
95   } else {
96     return (trans_A == 'n') ? dgrad_n_mi250x : dgrad_t_mi250x;
97   }
98 }
99 #else
100 static inline auto nontensor_rtc_get_data(int gpu_arch, char trans_A, int q_comp) -> decltype(dinterp_n_a100) {
101   if (q_comp == 1) {
102     return (trans_A == 'n') ? dinterp_n_a100 : dinterp_t_a100;
103   } else {
104     return (trans_A == 'n') ? dgrad_n_a100 : dgrad_t_a100;
105   }
106 }
107 #endif
108 
109 ////////////////////////////////////////////////////////////////////////////////
110 CeedInt nontensor_rtc_get_nb(int gpu_arch, char trans_A, int q_comp, int P, int Q, int n) {
111   const auto &data = nontensor_rtc_get_data(gpu_arch, trans_A, q_comp);
112   int         ir   = -1;
113   double      norm = std::numeric_limits<double>::max();
114   CeedInt     m    = (trans_A == 'n') ? Q : P;
115   CeedInt     k    = (trans_A == 'n') ? P : Q;
116 
117   for (size_t i = 0; i < data.size(); i++) {
118     const int &im = data[i][M_INDEX_RTC];
119     const int &in = data[i][N_INDEX_RTC];
120     const int &ik = data[i][K_INDEX_RTC];
121 
122     double mdiff = (double)(im - m);
123     double ndiff = (double)(in - n);
124     double kdiff = (double)(ik - k);
125     double nrm   = mdiff * mdiff + ndiff * ndiff + kdiff * kdiff;
126 
127     if (nrm < norm) {
128       norm = nrm;
129       ir   = i;
130     }
131 
132     if (im == m && in == n && ik == k) {
133       // The input (m, n, k) exactly matches a record in `data`, no need to search further
134       break;
135     }
136   }
137 
138   return (ir >= 0) ? data[ir][NB_INDEX_RTC] : 1;
139 }
140