xref: /libCEED/rust/libceed-sys/c-src/backends/magma/ceed-magma-gemm-selector.cpp (revision f80f4a748154eed4bc661c135f695b92b1bc45b9)
1*f80f4a74SSebastian Grimberg #include <stdio.h>
2*f80f4a74SSebastian Grimberg #include <sys/time.h>
3*f80f4a74SSebastian Grimberg 
4*f80f4a74SSebastian Grimberg #include <array>
5*f80f4a74SSebastian Grimberg #include <limits>
6*f80f4a74SSebastian Grimberg #include <vector>
7*f80f4a74SSebastian Grimberg 
8*f80f4a74SSebastian Grimberg #include "ceed-magma.h"
9*f80f4a74SSebastian Grimberg #include "tuning/indices.h"
10*f80f4a74SSebastian Grimberg #ifdef CEED_MAGMA_USE_HIP
11*f80f4a74SSebastian Grimberg #include "tuning/mi100.h"
12*f80f4a74SSebastian Grimberg #include "tuning/mi250x.h"
13*f80f4a74SSebastian Grimberg #include "tuning/mi250x_grad_rtc.h"
14*f80f4a74SSebastian Grimberg #include "tuning/mi250x_interp_rtc.h"
15*f80f4a74SSebastian Grimberg #else
16*f80f4a74SSebastian Grimberg #include "tuning/a100.h"
17*f80f4a74SSebastian Grimberg #include "tuning/a100_grad_rtc.h"
18*f80f4a74SSebastian Grimberg #include "tuning/a100_interp_rtc.h"
19*f80f4a74SSebastian Grimberg #include "tuning/v100.h"
20*f80f4a74SSebastian Grimberg #endif
21*f80f4a74SSebastian Grimberg 
22*f80f4a74SSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
23*f80f4a74SSebastian Grimberg static void *gemm_selector_get_data(int gpu_arch, char precision, char trans_A) {
24*f80f4a74SSebastian Grimberg // a default
25*f80f4a74SSebastian Grimberg #ifdef CEED_MAGMA_USE_HIP
26*f80f4a74SSebastian Grimberg   void *data = (void *)&sgemm_nn_mi250x;
27*f80f4a74SSebastian Grimberg #else
28*f80f4a74SSebastian Grimberg   void *data = (void *)&sgemm_nn_a100;
29*f80f4a74SSebastian Grimberg #endif
30*f80f4a74SSebastian Grimberg 
31*f80f4a74SSebastian Grimberg #ifdef CEED_MAGMA_USE_HIP
32*f80f4a74SSebastian Grimberg   if (gpu_arch >= 910) {
33*f80f4a74SSebastian Grimberg     // gfx90a or newer
34*f80f4a74SSebastian Grimberg     data = (precision == 's') ? ((trans_A == 'n') ? (void *)&sgemm_nn_mi250x : (void *)&sgemm_tn_mi250x)
35*f80f4a74SSebastian Grimberg                               : ((trans_A == 'n') ? (void *)&dgemm_nn_mi250x : (void *)&dgemm_tn_mi250x);
36*f80f4a74SSebastian Grimberg   } else {
37*f80f4a74SSebastian Grimberg     // gfx908 or older
38*f80f4a74SSebastian Grimberg     data = (precision == 's') ? ((trans_A == 'n') ? (void *)&sgemm_nn_mi100 : (void *)&sgemm_tn_mi100)
39*f80f4a74SSebastian Grimberg                               : ((trans_A == 'n') ? (void *)&dgemm_nn_mi100 : (void *)&dgemm_tn_mi100);
40*f80f4a74SSebastian Grimberg   }
41*f80f4a74SSebastian Grimberg #else
42*f80f4a74SSebastian Grimberg   if (gpu_arch >= 800) {
43*f80f4a74SSebastian Grimberg     // sm80 or newer
44*f80f4a74SSebastian Grimberg     data = (precision == 's') ? ((trans_A == 'n') ? (void *)&sgemm_nn_a100 : (void *)&sgemm_tn_a100)
45*f80f4a74SSebastian Grimberg                               : ((trans_A == 'n') ? (void *)&dgemm_nn_a100 : (void *)&dgemm_tn_a100);
46*f80f4a74SSebastian Grimberg   } else {
47*f80f4a74SSebastian Grimberg     // sm70 or older
48*f80f4a74SSebastian Grimberg     data = (precision == 's') ? ((trans_A == 'n') ? (void *)&sgemm_nn_v100 : (void *)&sgemm_tn_v100)
49*f80f4a74SSebastian Grimberg                               : ((trans_A == 'n') ? (void *)&dgemm_nn_v100 : (void *)&dgemm_tn_v100);
50*f80f4a74SSebastian Grimberg   }
51*f80f4a74SSebastian Grimberg #endif
52*f80f4a74SSebastian Grimberg 
53*f80f4a74SSebastian Grimberg   return data;
54*f80f4a74SSebastian Grimberg }
55*f80f4a74SSebastian Grimberg 
56*f80f4a74SSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
57*f80f4a74SSebastian Grimberg void gemm_selector(int gpu_arch, char precision, char trans_A, int m, int n, int k, int *n_batch, int *use_magma) {
58*f80f4a74SSebastian Grimberg   // defaults
59*f80f4a74SSebastian Grimberg   *n_batch                                           = n;
60*f80f4a74SSebastian Grimberg   *use_magma                                         = 0;
61*f80f4a74SSebastian Grimberg   std::vector<std::array<int, RECORD_LENGTH> > *data = NULL;
62*f80f4a74SSebastian Grimberg   data = (std::vector<std::array<int, RECORD_LENGTH> > *)gemm_selector_get_data(gpu_arch, precision, trans_A);
63*f80f4a74SSebastian Grimberg 
64*f80f4a74SSebastian Grimberg   int    ir   = -1;
65*f80f4a74SSebastian Grimberg   double norm = std::numeric_limits<double>::max();
66*f80f4a74SSebastian Grimberg   for (size_t i = 0; i < data->size(); i++) {
67*f80f4a74SSebastian Grimberg     int im = (*data)[i][M_INDEX];
68*f80f4a74SSebastian Grimberg     int in = (*data)[i][N_INDEX];
69*f80f4a74SSebastian Grimberg     int ik = (*data)[i][K_INDEX];
70*f80f4a74SSebastian Grimberg 
71*f80f4a74SSebastian Grimberg     double mdiff = (double)(im - m);
72*f80f4a74SSebastian Grimberg     double ndiff = (double)(in - n);
73*f80f4a74SSebastian Grimberg     double kdiff = (double)(ik - k);
74*f80f4a74SSebastian Grimberg 
75*f80f4a74SSebastian Grimberg     double nrm = sqrt(mdiff * mdiff + ndiff * ndiff + kdiff * kdiff);
76*f80f4a74SSebastian Grimberg 
77*f80f4a74SSebastian Grimberg     if (nrm < norm) {
78*f80f4a74SSebastian Grimberg       norm = nrm;
79*f80f4a74SSebastian Grimberg       ir   = i;
80*f80f4a74SSebastian Grimberg     }
81*f80f4a74SSebastian Grimberg 
82*f80f4a74SSebastian Grimberg     if (nrm == 0) {
83*f80f4a74SSebastian Grimberg       // the input (m, n, k) exactly matches a record in `data`
84*f80f4a74SSebastian Grimberg       // no need to search further
85*f80f4a74SSebastian Grimberg       break;
86*f80f4a74SSebastian Grimberg     }
87*f80f4a74SSebastian Grimberg   }
88*f80f4a74SSebastian Grimberg 
89*f80f4a74SSebastian Grimberg   if (ir >= 0) {
90*f80f4a74SSebastian Grimberg     *use_magma = (*data)[ir][USE_MAGMA_INDEX];
91*f80f4a74SSebastian Grimberg     // if the closest match indicates that n = n_batch,
92*f80f4a74SSebastian Grimberg     // that means calling the regular non-batch gemm.
93*f80f4a74SSebastian Grimberg     // So n_batch is set to n instead of the 'n_batch'
94*f80f4a74SSebastian Grimberg     // entry of the matching record
95*f80f4a74SSebastian Grimberg     int n_       = (*data)[ir][N_INDEX];
96*f80f4a74SSebastian Grimberg     int n_batch_ = (*data)[ir][N_BATCH_INDEX];
97*f80f4a74SSebastian Grimberg     *n_batch     = (n_ == n_batch_) ? n : n_batch_;
98*f80f4a74SSebastian Grimberg   }
99*f80f4a74SSebastian Grimberg }
100*f80f4a74SSebastian Grimberg 
101*f80f4a74SSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
102*f80f4a74SSebastian Grimberg static void *nontensor_rtc_get_data(int gpu_arch, char precision, CeedEvalMode e_mode, CeedTransposeMode t_mode) {
103*f80f4a74SSebastian Grimberg // a default
104*f80f4a74SSebastian Grimberg #ifdef CEED_MAGMA_USE_HIP
105*f80f4a74SSebastian Grimberg   void *data = (void *)&dinterp_n_mi250x;
106*f80f4a74SSebastian Grimberg #else
107*f80f4a74SSebastian Grimberg   void *data = (void *)&dinterp_n_a100;
108*f80f4a74SSebastian Grimberg #endif
109*f80f4a74SSebastian Grimberg 
110*f80f4a74SSebastian Grimberg #ifdef CEED_MAGMA_USE_HIP
111*f80f4a74SSebastian Grimberg   if (e_mode == CEED_EVAL_INTERP) {
112*f80f4a74SSebastian Grimberg     data = (t_mode == CEED_TRANSPOSE) ? (void *)&dinterp_t_mi250x : (void *)&dinterp_n_mi250x;
113*f80f4a74SSebastian Grimberg   } else if (e_mode == CEED_EVAL_GRAD) {
114*f80f4a74SSebastian Grimberg     data = (t_mode == CEED_TRANSPOSE) ? (void *)&dgrad_t_mi250x : (void *)&dgrad_n_mi250x;
115*f80f4a74SSebastian Grimberg   }
116*f80f4a74SSebastian Grimberg #else
117*f80f4a74SSebastian Grimberg   if (e_mode == CEED_EVAL_INTERP) {
118*f80f4a74SSebastian Grimberg     data = (t_mode == CEED_TRANSPOSE) ? (void *)&dinterp_t_a100 : (void *)&dinterp_n_a100;
119*f80f4a74SSebastian Grimberg   } else if (e_mode == CEED_EVAL_GRAD) {
120*f80f4a74SSebastian Grimberg     data = (t_mode == CEED_TRANSPOSE) ? (void *)&dgrad_t_a100 : (void *)&dgrad_n_a100;
121*f80f4a74SSebastian Grimberg   }
122*f80f4a74SSebastian Grimberg #endif
123*f80f4a74SSebastian Grimberg 
124*f80f4a74SSebastian Grimberg   return data;
125*f80f4a74SSebastian Grimberg }
126*f80f4a74SSebastian Grimberg 
127*f80f4a74SSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
128*f80f4a74SSebastian Grimberg CeedInt nontensor_rtc_get_nb(int gpu_arch, char precision, CeedEvalMode e_mode, CeedTransposeMode t_mode, int P_, int N, int Q_) {
129*f80f4a74SSebastian Grimberg   CeedInt P  = (t_mode == CEED_TRANSPOSE) ? P_ : Q_;
130*f80f4a74SSebastian Grimberg   CeedInt Q  = (t_mode == CEED_TRANSPOSE) ? Q_ : P_;
131*f80f4a74SSebastian Grimberg   CeedInt NB = 1;
132*f80f4a74SSebastian Grimberg 
133*f80f4a74SSebastian Grimberg   std::vector<std::array<int, RECORD_LENGTH_RTC> > *data = NULL;
134*f80f4a74SSebastian Grimberg   data = (std::vector<std::array<int, RECORD_LENGTH_RTC> > *)nontensor_rtc_get_data(gpu_arch, precision, e_mode, t_mode);
135*f80f4a74SSebastian Grimberg 
136*f80f4a74SSebastian Grimberg   int    ir   = -1;
137*f80f4a74SSebastian Grimberg   double norm = std::numeric_limits<double>::max();
138*f80f4a74SSebastian Grimberg   for (size_t i = 0; i < data->size(); i++) {
139*f80f4a74SSebastian Grimberg     int ip = (*data)[i][M_INDEX_RTC];
140*f80f4a74SSebastian Grimberg     int in = (*data)[i][N_INDEX_RTC];
141*f80f4a74SSebastian Grimberg     int iq = (*data)[i][K_INDEX_RTC];
142*f80f4a74SSebastian Grimberg 
143*f80f4a74SSebastian Grimberg     double pdiff = (double)(ip - P);
144*f80f4a74SSebastian Grimberg     double ndiff = (double)(in - N);
145*f80f4a74SSebastian Grimberg     double qdiff = (double)(iq - Q);
146*f80f4a74SSebastian Grimberg     double nrm   = sqrt(pdiff * pdiff + ndiff * ndiff + qdiff * qdiff);
147*f80f4a74SSebastian Grimberg 
148*f80f4a74SSebastian Grimberg     if (nrm < norm) {
149*f80f4a74SSebastian Grimberg       norm = nrm;
150*f80f4a74SSebastian Grimberg       ir   = i;
151*f80f4a74SSebastian Grimberg     }
152*f80f4a74SSebastian Grimberg 
153*f80f4a74SSebastian Grimberg     if (nrm == 0) {
154*f80f4a74SSebastian Grimberg       // the input (m, n, k) exactly matches a record in `data`
155*f80f4a74SSebastian Grimberg       // no need to search further
156*f80f4a74SSebastian Grimberg       break;
157*f80f4a74SSebastian Grimberg     }
158*f80f4a74SSebastian Grimberg   }
159*f80f4a74SSebastian Grimberg 
160*f80f4a74SSebastian Grimberg   if (ir >= 0) {
161*f80f4a74SSebastian Grimberg     NB = (*data)[ir][NB_INDEX_RTC];
162*f80f4a74SSebastian Grimberg   }
163*f80f4a74SSebastian Grimberg 
164*f80f4a74SSebastian Grimberg   return NB;
165*f80f4a74SSebastian Grimberg }
166