xref: /libCEED/backends/avx/ceed-avx-tensor.c (revision 778419476db2fdf3952b1b9a5faed3f766f76223)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <ceed/ceed.h>
18 #include <ceed/backend.h>
19 #include <immintrin.h>
20 #include <stdbool.h>
21 #include "ceed-avx.h"
22 
// c += a * b
// Uses a fused multiply-add when the target supports it; otherwise falls
// back to a separate multiply and add.  The fallback uses the portable
// _mm256_add_pd intrinsic: writing '(c) += _mm256_mul_pd(...)' relies on
// GCC/Clang vector-extension arithmetic on __m256d and does not compile
// with all compilers (e.g. MSVC).
#ifdef __FMA__
#  define fmadd(c,a,b) (c) = _mm256_fmadd_pd((a), (b), (c))
#else
#  define fmadd(c,a,b) (c) = _mm256_add_pd((c), _mm256_mul_pd((a), (b)))
#endif
29 
30 //------------------------------------------------------------------------------
31 // Blocked Tensor Contract
32 //------------------------------------------------------------------------------
// Compute the blocked tensor contraction
//   v[a,j,c] += sum_b t[j,b] * u[a,b,c]
// (t is indexed t[b,j] when t_mode == CEED_TRANSPOSE), holding a JJ x CC
// output tile of __m256d accumulators in registers.  CC must be a multiple
// of 4.  Only full CC-wide column blocks (c < (C/CC)*CC) are handled here;
// leftover columns are finished by CeedTensorContract_Avx_Remainder.
// NOTE(review): the 'add' flag is ignored here -- v is always accumulated
// into; the caller (CeedTensorContractApply_Avx) zeroes v first when !add.
static inline int CeedTensorContract_Avx_Blocked(CeedTensorContract contract,
    CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
    CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u,
    CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) {
  // Strides so t is always read as t[j*t_stride_0 + b*t_stride_1]
  CeedInt t_stride_0 = B, t_stride_1 = 1;
  if (t_mode == CEED_TRANSPOSE) {
    t_stride_0 = 1; t_stride_1 = J;
  }

  for (CeedInt a=0; a<A; a++) {
    // Full blocks of JJ rows
    for (CeedInt j=0; j<(J/JJ)*JJ; j+=JJ) {
      for (CeedInt c=0; c<(C/CC)*CC; c+=CC) {
        __m256d vv[JJ][CC/4]; // Output tile to be held in registers
        // Load the current v tile
        for (CeedInt jj=0; jj<JJ; jj++)
          for (CeedInt cc=0; cc<CC/4; cc++)
            vv[jj][cc] = _mm256_loadu_pd(&v[(a*J+j+jj)*C+c+cc*4]);

        // Accumulate over the contraction index b
        for (CeedInt b=0; b<B; b++) {
          for (CeedInt jj=0; jj<JJ; jj++) { // unroll
            __m256d tqv = _mm256_set1_pd(t[(j+jj)*t_stride_0 + b*t_stride_1]);
            for (CeedInt cc=0; cc<CC/4; cc++) // unroll
              fmadd(vv[jj][cc], tqv, _mm256_loadu_pd(&u[(a*B+b)*C+c+cc*4]));
          }
        }
        // Write the tile back
        for (CeedInt jj=0; jj<JJ; jj++)
          for (CeedInt cc=0; cc<CC/4; cc++)
            _mm256_storeu_pd(&v[(a*J+j+jj)*C+c+cc*4], vv[jj][cc]);
      }
    }
    // Remainder of rows: same scheme, but only the J-j valid rows of the
    // tile are loaded, updated, and stored
    CeedInt j=(J/JJ)*JJ;
    if (j < J) {
      for (CeedInt c=0; c<(C/CC)*CC; c+=CC) {
        __m256d vv[JJ][CC/4]; // Output tile to be held in registers
        for (CeedInt jj=0; jj<J-j; jj++)
          for (CeedInt cc=0; cc<CC/4; cc++)
            vv[jj][cc] = _mm256_loadu_pd(&v[(a*J+j+jj)*C+c+cc*4]);

        for (CeedInt b=0; b<B; b++) {
          for (CeedInt jj=0; jj<J-j; jj++) { // doesn't unroll
            __m256d tqv = _mm256_set1_pd(t[(j+jj)*t_stride_0 + b*t_stride_1]);
            for (CeedInt cc=0; cc<CC/4; cc++) // unroll
              fmadd(vv[jj][cc], tqv, _mm256_loadu_pd(&u[(a*B+b)*C+c+cc*4]));
          }
        }
        for (CeedInt jj=0; jj<J-j; jj++)
          for (CeedInt cc=0; cc<CC/4; cc++)
            _mm256_storeu_pd(&v[(a*J+j+jj)*C+c+cc*4], vv[jj][cc]);
      }
    }
  }
  return CEED_ERROR_SUCCESS;
}
87 
88 //------------------------------------------------------------------------------
89 // Serial Tensor Contract Remainder
90 //------------------------------------------------------------------------------
// Finish the leftover columns (c >= (C/CC)*CC) of the contraction
//   v[a,j,c] += sum_b t[j,b] * u[a,b,c]
// (t is indexed t[b,j] when t_mode == CEED_TRANSPOSE) in 4-wide __m256d
// chunks, zero-padding u when fewer than 4 columns remain so the extra
// lanes accumulate nothing.
// NOTE(review): the 'add' flag is ignored here -- v is always accumulated
// into; the caller (CeedTensorContractApply_Avx) zeroes v first when !add.
static inline int CeedTensorContract_Avx_Remainder(CeedTensorContract contract,
    CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
    CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u,
    CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) {
  // Strides so t is always read as t[j*t_stride_0 + b*t_stride_1]
  CeedInt t_stride_0 = B, t_stride_1 = 1;
  if (t_mode == CEED_TRANSPOSE) {
    t_stride_0 = 1; t_stride_1 = J;
  }

  // J_break excludes the final row block so the 4-wide loads/stores on v
  // below (which can spill up to 3 doubles into the following row when
  // C-c < 4; the spilled lanes are written back unchanged) never run past
  // the end of v; the excluded rows are finished by the scalar loop below.
  CeedInt J_break = J%JJ ? (J/JJ)*JJ : (J/JJ-1)*JJ;
  for (CeedInt a=0; a<A; a++) {
    // Blocks of 4 columns
    for (CeedInt c = (C/CC)*CC; c<C; c+=4) {
      // Blocks of 4 rows
      for (CeedInt j=0; j<J_break; j+=JJ) {
        __m256d vv[JJ]; // Output tile to be held in registers
        for (CeedInt jj=0; jj<JJ; jj++)
          vv[jj] = _mm256_loadu_pd(&v[(a*J+j+jj)*C+c]);

        for (CeedInt b=0; b<B; b++) {
          // Zero-pad the u vector when fewer than 4 columns remain
          __m256d tqu;
          if (C-c == 1)
            tqu = _mm256_set_pd(0.0, 0.0, 0.0, u[(a*B+b)*C+c+0]);
          else if (C-c == 2)
            tqu = _mm256_set_pd(0.0, 0.0, u[(a*B+b)*C+c+1],
                                u[(a*B+b)*C+c+0]);
          else if (C-c == 3)
            tqu = _mm256_set_pd(0.0, u[(a*B+b)*C+c+2], u[(a*B+b)*C+c+1],
                                u[(a*B+b)*C+c+0]);
          else
            tqu = _mm256_loadu_pd(&u[(a*B+b)*C+c]);
          for (CeedInt jj=0; jj<JJ; jj++) // unroll
            fmadd(vv[jj], tqu, _mm256_set1_pd(t[(j+jj)*t_stride_0 + b*t_stride_1]));
        }
        for (CeedInt jj=0; jj<JJ; jj++)
          _mm256_storeu_pd(&v[(a*J+j+jj)*C+c], vv[jj]);
      }
    }
    // Remainder of rows, all leftover columns (scalar)
    for (CeedInt j=J_break; j<J; j++)
      for (CeedInt b=0; b<B; b++) {
        CeedScalar tq = t[j*t_stride_0 + b*t_stride_1];
        for (CeedInt c=(C/CC)*CC; c<C; c++)
          v[(a*J+j)*C+c] += tq * u[(a*B+b)*C+c];
      }
  }
  return CEED_ERROR_SUCCESS;
}
139 
140 //------------------------------------------------------------------------------
141 // Serial Tensor Contract C=1
142 //------------------------------------------------------------------------------
143 static inline int CeedTensorContract_Avx_Single(CeedTensorContract contract,
144     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
145     CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u,
146     CeedScalar *restrict v, const CeedInt AA, const CeedInt JJ) {
147   CeedInt t_stride_0 = B, t_stride_1 = 1;
148   if (t_mode == CEED_TRANSPOSE) {
149     t_stride_0 = 1; t_stride_1 = J;
150   }
151 
152   // Blocks of 4 rows
153   for (CeedInt a=0; a<(A/AA)*AA; a+=AA) {
154     for (CeedInt j=0; j<(J/JJ)*JJ; j+=JJ) {
155       __m256d vv[AA][JJ/4]; // Output tile to be held in registers
156       for (CeedInt aa=0; aa<AA; aa++)
157         for (CeedInt jj=0; jj<JJ/4; jj++)
158           vv[aa][jj] = _mm256_loadu_pd(&v[(a+aa)*J+j+jj*4]);
159 
160       for (CeedInt b=0; b<B; b++) {
161         for (CeedInt jj=0; jj<JJ/4; jj++) { // unroll
162           __m256d tqv = _mm256_set_pd(t[(j+jj*4+3)*t_stride_0 + b*t_stride_1],
163                                       t[(j+jj*4+2)*t_stride_0 + b*t_stride_1],
164                                       t[(j+jj*4+1)*t_stride_0 + b*t_stride_1],
165                                       t[(j+jj*4+0)*t_stride_0 + b*t_stride_1]);
166           for (CeedInt aa=0; aa<AA; aa++) // unroll
167             fmadd(vv[aa][jj], tqv, _mm256_set1_pd(u[(a+aa)*B+b]));
168         }
169       }
170       for (CeedInt aa=0; aa<AA; aa++)
171         for (CeedInt jj=0; jj<JJ/4; jj++)
172           _mm256_storeu_pd(&v[(a+aa)*J+j+jj*4], vv[aa][jj]);
173     }
174   }
175   // Remainder of rows
176   CeedInt a=(A/AA)*AA;
177   for (CeedInt j=0; j<(J/JJ)*JJ; j+=JJ) {
178     __m256d vv[AA][JJ/4]; // Output tile to be held in registers
179     for (CeedInt aa=0; aa<A-a; aa++)
180       for (CeedInt jj=0; jj<JJ/4; jj++)
181         vv[aa][jj] = _mm256_loadu_pd(&v[(a+aa)*J+j+jj*4]);
182 
183     for (CeedInt b=0; b<B; b++) {
184       for (CeedInt jj=0; jj<JJ/4; jj++) { // unroll
185         __m256d tqv = _mm256_set_pd(t[(j+jj*4+3)*t_stride_0 + b*t_stride_1],
186                                     t[(j+jj*4+2)*t_stride_0 + b*t_stride_1],
187                                     t[(j+jj*4+1)*t_stride_0 + b*t_stride_1],
188                                     t[(j+jj*4+0)*t_stride_0 + b*t_stride_1]);
189         for (CeedInt aa=0; aa<A-a; aa++) // unroll
190           fmadd(vv[aa][jj], tqv, _mm256_set1_pd(u[(a+aa)*B+b]));
191       }
192     }
193     for (CeedInt aa=0; aa<A-a; aa++)
194       for (CeedInt jj=0; jj<JJ/4; jj++)
195         _mm256_storeu_pd(&v[(a+aa)*J+j+jj*4], vv[aa][jj]);
196   }
197   // Column remainder
198   CeedInt A_break = A%AA ? (A/AA)*AA : (A/AA-1)*AA;
199   // Blocks of 4 columns
200   for (CeedInt j = (J/JJ)*JJ; j<J; j+=4) {
201     // Blocks of 4 rows
202     for (CeedInt a=0; a<A_break; a+=AA) {
203       __m256d vv[AA]; // Output tile to be held in registers
204       for (CeedInt aa=0; aa<AA; aa++)
205         vv[aa] = _mm256_loadu_pd(&v[(a+aa)*J+j]);
206 
207       for (CeedInt b=0; b<B; b++) {
208         __m256d tqv;
209         if (J-j == 1)
210           tqv = _mm256_set_pd(0.0, 0.0, 0.0, t[(j+0)*t_stride_0 + b*t_stride_1]);
211         else if (J-j == 2)
212           tqv = _mm256_set_pd(0.0, 0.0, t[(j+1)*t_stride_0 + b*t_stride_1],
213                               t[(j+0)*t_stride_0 + b*t_stride_1]);
214         else if (J-3 == j)
215           tqv = _mm256_set_pd(0.0, t[(j+2)*t_stride_0 + b*t_stride_1],
216                               t[(j+1)*t_stride_0 + b*t_stride_1],
217                               t[(j+0)*t_stride_0 + b*t_stride_1]);
218         else
219           tqv = _mm256_set_pd(t[(j+3)*t_stride_0 + b*t_stride_1],
220                               t[(j+2)*t_stride_0 + b*t_stride_1],
221                               t[(j+1)*t_stride_0 + b*t_stride_1],
222                               t[(j+0)*t_stride_0 + b*t_stride_1]);
223         for (CeedInt aa=0; aa<AA; aa++) // unroll
224           fmadd(vv[aa], tqv, _mm256_set1_pd(u[(a+aa)*B+b]));
225       }
226       for (CeedInt aa=0; aa<AA; aa++)
227         _mm256_storeu_pd(&v[(a+aa)*J+j], vv[aa]);
228     }
229   }
230   // Remainder of rows, all columns
231   for (CeedInt b=0; b<B; b++) {
232     for (CeedInt j=(J/JJ)*JJ; j<J; j++) {
233       CeedScalar tq = t[j*t_stride_0 + b*t_stride_1];
234       for (CeedInt a=A_break; a<A; a++)
235         v[a*J+j] += tq * u[a*B+b];
236     }
237   }
238   return CEED_ERROR_SUCCESS;
239 }
240 
241 //------------------------------------------------------------------------------
242 // Tensor Contract - Common Sizes
243 //------------------------------------------------------------------------------
244 static int CeedTensorContract_Avx_Blocked_4_8(CeedTensorContract contract,
245     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
246     CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u,
247     CeedScalar *restrict v) {
248   return CeedTensorContract_Avx_Blocked(contract, A, B, C, J, t, t_mode, add, u,
249                                         v, 4, 8);
250 }
251 static int CeedTensorContract_Avx_Remainder_8_8(CeedTensorContract contract,
252     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
253     CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u,
254     CeedScalar *restrict v) {
255   return CeedTensorContract_Avx_Remainder(contract, A, B, C, J, t, t_mode, add,
256                                           u, v, 8, 8);
257 }
258 static int CeedTensorContract_Avx_Single_4_8(CeedTensorContract contract,
259     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
260     CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u,
261     CeedScalar *restrict v) {
262   return CeedTensorContract_Avx_Single(contract, A, B, C, J, t, t_mode, add, u,
263                                        v, 4, 8);
264 }
265 
266 //------------------------------------------------------------------------------
267 // Tensor Contract Apply
268 //------------------------------------------------------------------------------
269 static int CeedTensorContractApply_Avx(CeedTensorContract contract, CeedInt A,
270                                        CeedInt B, CeedInt C, CeedInt J,
271                                        const CeedScalar *restrict t,
272                                        CeedTransposeMode t_mode,
273                                        const CeedInt add,
274                                        const CeedScalar *restrict u,
275                                        CeedScalar *restrict v) {
276   const CeedInt blk_size = 8;
277 
278   if (!add)
279     for (CeedInt q=0; q<A*J*C; q++)
280       v[q] = (CeedScalar) 0.0;
281 
282   if (C == 1) {
283     // Serial C=1 Case
284     CeedTensorContract_Avx_Single_4_8(contract, A, B, C, J, t, t_mode, true, u,
285                                       v);
286   } else {
287     // Blocks of 8 columns
288     if (C >= blk_size)
289       CeedTensorContract_Avx_Blocked_4_8(contract, A, B, C, J, t, t_mode, true,
290                                          u, v);
291     // Remainder of columns
292     if (C % blk_size)
293       CeedTensorContract_Avx_Remainder_8_8(contract, A, B, C, J, t, t_mode, true,
294                                            u, v);
295   }
296 
297   return CEED_ERROR_SUCCESS;
298 }
299 
300 //------------------------------------------------------------------------------
301 // Tensor Contract Destroy
302 //------------------------------------------------------------------------------
// Backend destructor -- this backend attaches no per-contract data, so
// there is nothing to free.
static int CeedTensorContractDestroy_Avx(CeedTensorContract contract) {
  return CEED_ERROR_SUCCESS;
}
306 
307 //------------------------------------------------------------------------------
308 // Tensor Contract Create
309 //------------------------------------------------------------------------------
// Create the AVX tensor-contraction backend: register the Apply and
// Destroy implementations with the Ceed owning this contract object.
// 'basis' is unused by this backend.
int CeedTensorContractCreate_Avx(CeedBasis basis, CeedTensorContract contract) {
  int ierr;
  Ceed ceed;
  ierr = CeedTensorContractGetCeed(contract, &ceed); CeedChkBackend(ierr);

  // Entry points invoked through the generic CeedTensorContract interface
  ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply",
                                CeedTensorContractApply_Avx); CeedChkBackend(ierr);
  ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Destroy",
                                CeedTensorContractDestroy_Avx); CeedChkBackend(ierr);

  return CEED_ERROR_SUCCESS;
}
322 //------------------------------------------------------------------------------
323