xref: /libCEED/rust/libceed-sys/c-src/backends/magma/ceed-magma-gemm-selector.cpp (revision f80f4a748154eed4bc661c135f695b92b1bc45b9)
1 #include <stdio.h>
2 #include <sys/time.h>
3 
4 #include <array>
5 #include <limits>
6 #include <vector>
7 
8 #include "ceed-magma.h"
9 #include "tuning/indices.h"
10 #ifdef CEED_MAGMA_USE_HIP
11 #include "tuning/mi100.h"
12 #include "tuning/mi250x.h"
13 #include "tuning/mi250x_grad_rtc.h"
14 #include "tuning/mi250x_interp_rtc.h"
15 #else
16 #include "tuning/a100.h"
17 #include "tuning/a100_grad_rtc.h"
18 #include "tuning/a100_interp_rtc.h"
19 #include "tuning/v100.h"
20 #endif
21 
22 ////////////////////////////////////////////////////////////////////////////////
23 static void *gemm_selector_get_data(int gpu_arch, char precision, char trans_A) {
24 // a default
25 #ifdef CEED_MAGMA_USE_HIP
26   void *data = (void *)&sgemm_nn_mi250x;
27 #else
28   void *data = (void *)&sgemm_nn_a100;
29 #endif
30 
31 #ifdef CEED_MAGMA_USE_HIP
32   if (gpu_arch >= 910) {
33     // gfx90a or newer
34     data = (precision == 's') ? ((trans_A == 'n') ? (void *)&sgemm_nn_mi250x : (void *)&sgemm_tn_mi250x)
35                               : ((trans_A == 'n') ? (void *)&dgemm_nn_mi250x : (void *)&dgemm_tn_mi250x);
36   } else {
37     // gfx908 or older
38     data = (precision == 's') ? ((trans_A == 'n') ? (void *)&sgemm_nn_mi100 : (void *)&sgemm_tn_mi100)
39                               : ((trans_A == 'n') ? (void *)&dgemm_nn_mi100 : (void *)&dgemm_tn_mi100);
40   }
41 #else
42   if (gpu_arch >= 800) {
43     // sm80 or newer
44     data = (precision == 's') ? ((trans_A == 'n') ? (void *)&sgemm_nn_a100 : (void *)&sgemm_tn_a100)
45                               : ((trans_A == 'n') ? (void *)&dgemm_nn_a100 : (void *)&dgemm_tn_a100);
46   } else {
47     // sm70 or older
48     data = (precision == 's') ? ((trans_A == 'n') ? (void *)&sgemm_nn_v100 : (void *)&sgemm_tn_v100)
49                               : ((trans_A == 'n') ? (void *)&dgemm_nn_v100 : (void *)&dgemm_tn_v100);
50   }
51 #endif
52 
53   return data;
54 }
55 
56 ////////////////////////////////////////////////////////////////////////////////
57 void gemm_selector(int gpu_arch, char precision, char trans_A, int m, int n, int k, int *n_batch, int *use_magma) {
58   // defaults
59   *n_batch                                           = n;
60   *use_magma                                         = 0;
61   std::vector<std::array<int, RECORD_LENGTH> > *data = NULL;
62   data = (std::vector<std::array<int, RECORD_LENGTH> > *)gemm_selector_get_data(gpu_arch, precision, trans_A);
63 
64   int    ir   = -1;
65   double norm = std::numeric_limits<double>::max();
66   for (size_t i = 0; i < data->size(); i++) {
67     int im = (*data)[i][M_INDEX];
68     int in = (*data)[i][N_INDEX];
69     int ik = (*data)[i][K_INDEX];
70 
71     double mdiff = (double)(im - m);
72     double ndiff = (double)(in - n);
73     double kdiff = (double)(ik - k);
74 
75     double nrm = sqrt(mdiff * mdiff + ndiff * ndiff + kdiff * kdiff);
76 
77     if (nrm < norm) {
78       norm = nrm;
79       ir   = i;
80     }
81 
82     if (nrm == 0) {
83       // the input (m, n, k) exactly matches a record in `data`
84       // no need to search further
85       break;
86     }
87   }
88 
89   if (ir >= 0) {
90     *use_magma = (*data)[ir][USE_MAGMA_INDEX];
91     // if the closest match indicates that n = n_batch,
92     // that means calling the regular non-batch gemm.
93     // So n_batch is set to n instead of the 'n_batch'
94     // entry of the matching record
95     int n_       = (*data)[ir][N_INDEX];
96     int n_batch_ = (*data)[ir][N_BATCH_INDEX];
97     *n_batch     = (n_ == n_batch_) ? n : n_batch_;
98   }
99 }
100 
101 ////////////////////////////////////////////////////////////////////////////////
102 static void *nontensor_rtc_get_data(int gpu_arch, char precision, CeedEvalMode e_mode, CeedTransposeMode t_mode) {
103 // a default
104 #ifdef CEED_MAGMA_USE_HIP
105   void *data = (void *)&dinterp_n_mi250x;
106 #else
107   void *data = (void *)&dinterp_n_a100;
108 #endif
109 
110 #ifdef CEED_MAGMA_USE_HIP
111   if (e_mode == CEED_EVAL_INTERP) {
112     data = (t_mode == CEED_TRANSPOSE) ? (void *)&dinterp_t_mi250x : (void *)&dinterp_n_mi250x;
113   } else if (e_mode == CEED_EVAL_GRAD) {
114     data = (t_mode == CEED_TRANSPOSE) ? (void *)&dgrad_t_mi250x : (void *)&dgrad_n_mi250x;
115   }
116 #else
117   if (e_mode == CEED_EVAL_INTERP) {
118     data = (t_mode == CEED_TRANSPOSE) ? (void *)&dinterp_t_a100 : (void *)&dinterp_n_a100;
119   } else if (e_mode == CEED_EVAL_GRAD) {
120     data = (t_mode == CEED_TRANSPOSE) ? (void *)&dgrad_t_a100 : (void *)&dgrad_n_a100;
121   }
122 #endif
123 
124   return data;
125 }
126 
127 ////////////////////////////////////////////////////////////////////////////////
128 CeedInt nontensor_rtc_get_nb(int gpu_arch, char precision, CeedEvalMode e_mode, CeedTransposeMode t_mode, int P_, int N, int Q_) {
129   CeedInt P  = (t_mode == CEED_TRANSPOSE) ? P_ : Q_;
130   CeedInt Q  = (t_mode == CEED_TRANSPOSE) ? Q_ : P_;
131   CeedInt NB = 1;
132 
133   std::vector<std::array<int, RECORD_LENGTH_RTC> > *data = NULL;
134   data = (std::vector<std::array<int, RECORD_LENGTH_RTC> > *)nontensor_rtc_get_data(gpu_arch, precision, e_mode, t_mode);
135 
136   int    ir   = -1;
137   double norm = std::numeric_limits<double>::max();
138   for (size_t i = 0; i < data->size(); i++) {
139     int ip = (*data)[i][M_INDEX_RTC];
140     int in = (*data)[i][N_INDEX_RTC];
141     int iq = (*data)[i][K_INDEX_RTC];
142 
143     double pdiff = (double)(ip - P);
144     double ndiff = (double)(in - N);
145     double qdiff = (double)(iq - Q);
146     double nrm   = sqrt(pdiff * pdiff + ndiff * ndiff + qdiff * qdiff);
147 
148     if (nrm < norm) {
149       norm = nrm;
150       ir   = i;
151     }
152 
153     if (nrm == 0) {
154       // the input (m, n, k) exactly matches a record in `data`
155       // no need to search further
156       break;
157     }
158   }
159 
160   if (ir >= 0) {
161     NB = (*data)[ir][NB_INDEX_RTC];
162   }
163 
164   return NB;
165 }
166