xref: /libCEED/backends/avx/ceed-avx-tensor.c (revision 5f67fade47e323fa44018f277580acfe24400ad4)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include "ceed-avx.h"
18 
// c += a * b (elementwise on 4 packed doubles)
//
// With FMA available this is a single fused multiply-add (one rounding step);
// otherwise it falls back to separate multiply and add intrinsics.  The
// fallback uses an explicit _mm256_add_pd rather than the GNU "+=" vector
// operator on __m256d, so the macro also compiles with non-GNU compilers
// (e.g. MSVC).  Note the two branches may differ in the last bit due to the
// extra rounding of the mul/add pair.
#ifdef __FMA__
#  define fmadd(c,a,b) (c) = _mm256_fmadd_pd((a), (b), (c))
#else
#  define fmadd(c,a,b) (c) = _mm256_add_pd((c), _mm256_mul_pd((a), (b)))
#endif
25 
26 //------------------------------------------------------------------------------
27 // Blocked Tensor Contract
28 //------------------------------------------------------------------------------
// Compute v[a,j,c] += t[j,b] * u[a,b,c] (t transposed when tmode is
// CEED_TRANSPOSE), blocking the output into JJ-row by CC-column tiles so each
// tile stays resident in __m256d registers across the entire b loop.
//
//   A, B, C, J - tensor extents: b in [0,B) is contracted away, j replaces it
//   t          - contraction matrix, read as t[j*tstride0 + b*tstride1]
//   tmode      - CEED_TRANSPOSE swaps which index of t has unit stride
//   Add        - unused here: this kernel always accumulates into v
//                (the caller zeros v beforehand when not adding)
//   u, v       - input (A*B*C doubles) and output (A*J*C doubles)
//   JJ, CC     - register-tile shape; CC must be a multiple of 4
//                (4 doubles per __m256d)
//
// Only full CC-wide column blocks are processed here; the column tail
// c >= (C/CC)*CC is handled by CeedTensorContract_Avx_Remainder.
static inline int CeedTensorContract_Avx_Blocked(CeedTensorContract contract,
    CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
    CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
    CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) {
  // Default (CEED_NOTRANSPOSE): t is indexed t[j*B + b]
  CeedInt tstride0 = B, tstride1 = 1;
  if (tmode == CEED_TRANSPOSE) {
    // Transposed: t is indexed t[b*J + j]
    tstride0 = 1; tstride1 = J;
  }

  for (CeedInt a=0; a<A; a++) {
    // Blocks of JJ rows
    for (CeedInt j=0; j<(J/JJ)*JJ; j+=JJ) {
      for (CeedInt c=0; c<(C/CC)*CC; c+=CC) {
        __m256d vv[JJ][CC/4]; // Output tile to be held in registers
        // Load the current JJ x CC output tile
        for (CeedInt jj=0; jj<JJ; jj++)
          for (CeedInt cc=0; cc<CC/4; cc++)
            vv[jj][cc] = _mm256_loadu_pd(&v[(a*J+j+jj)*C+c+cc*4]);

        // Accumulate every b's contribution into the register tile
        for (CeedInt b=0; b<B; b++) {
          for (CeedInt jj=0; jj<JJ; jj++) { // unroll
            // Broadcast the scalar t entry across all 4 lanes
            __m256d tqv = _mm256_set1_pd(t[(j+jj)*tstride0 + b*tstride1]);
            for (CeedInt cc=0; cc<CC/4; cc++) // unroll
              fmadd(vv[jj][cc], tqv, _mm256_loadu_pd(&u[(a*B+b)*C+c+cc*4]));
          }
        }
        // Write the finished tile back to v
        for (CeedInt jj=0; jj<JJ; jj++)
          for (CeedInt cc=0; cc<CC/4; cc++)
            _mm256_storeu_pd(&v[(a*J+j+jj)*C+c+cc*4], vv[jj][cc]);
      }
    }
    // Remainder of rows: same tile loop, but only the J-j (< JJ) live rows of
    // the tile are loaded, accumulated, and stored
    CeedInt j=(J/JJ)*JJ;
    if (j < J) {
      for (CeedInt c=0; c<(C/CC)*CC; c+=CC) {
        __m256d vv[JJ][CC/4]; // Output tile to be held in registers
        for (CeedInt jj=0; jj<J-j; jj++)
          for (CeedInt cc=0; cc<CC/4; cc++)
            vv[jj][cc] = _mm256_loadu_pd(&v[(a*J+j+jj)*C+c+cc*4]);

        for (CeedInt b=0; b<B; b++) {
          for (CeedInt jj=0; jj<J-j; jj++) { // doesn't unroll
            __m256d tqv = _mm256_set1_pd(t[(j+jj)*tstride0 + b*tstride1]);
            for (CeedInt cc=0; cc<CC/4; cc++) // unroll
              fmadd(vv[jj][cc], tqv, _mm256_loadu_pd(&u[(a*B+b)*C+c+cc*4]));
          }
        }
        for (CeedInt jj=0; jj<J-j; jj++)
          for (CeedInt cc=0; cc<CC/4; cc++)
            _mm256_storeu_pd(&v[(a*J+j+jj)*C+c+cc*4], vv[jj][cc]);
      }
    }
  }
  return 0;
}
83 
84 //------------------------------------------------------------------------------
85 // Serial Tensor Contract Remainder
86 //------------------------------------------------------------------------------
// Handle the column tail c in [(C/CC)*CC, C) that the blocked kernel skipped,
// in 4-wide vector strips.  u lanes past the last column are zero-padded so
// they contribute nothing to the accumulation; the corresponding v lanes are
// loaded and stored back unchanged.
//
// Jbreak is chosen so the final row block always goes through the scalar loop
// at the bottom: a 4-wide tail load/store on row r can spill into the start of
// row r+1 (harmlessly, since those lanes are rewritten unmodified and later
// strips overwrite them with correct values), so keeping the last row(s)
// scalar ensures no vector access runs past the end of v.
// Arguments match CeedTensorContract_Avx_Blocked (Add is likewise unused;
// the kernel always accumulates).
static inline int CeedTensorContract_Avx_Remainder(CeedTensorContract contract,
    CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
    CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
    CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) {
  // Default (CEED_NOTRANSPOSE): t[j*B + b]; transposed: t[b*J + j]
  CeedInt tstride0 = B, tstride1 = 1;
  if (tmode == CEED_TRANSPOSE) {
    tstride0 = 1; tstride1 = J;
  }

  // Last JJ-row block (or the partial block) is deliberately excluded from the
  // vector loop -- see note above
  CeedInt Jbreak = J%JJ ? (J/JJ)*JJ : (J/JJ-1)*JJ;
  for (CeedInt a=0; a<A; a++) {
    // Blocks of 4 columns
    for (CeedInt c = (C/CC)*CC; c<C; c+=4) {
      // Blocks of JJ rows
      for (CeedInt j=0; j<Jbreak; j+=JJ) {
        __m256d vv[JJ]; // Output tile to be held in registers
        for (CeedInt jj=0; jj<JJ; jj++)
          vv[jj] = _mm256_loadu_pd(&v[(a*J+j+jj)*C+c]);

        for (CeedInt b=0; b<B; b++) {
          // Load up to 4 u values, zero-padding lanes past column C
          __m256d tqu;
          if (C-c == 1)
            tqu = _mm256_set_pd(0.0, 0.0, 0.0, u[(a*B+b)*C+c+0]);
          else if (C-c == 2)
            tqu = _mm256_set_pd(0.0, 0.0, u[(a*B+b)*C+c+1],
                                u[(a*B+b)*C+c+0]);
          else if (C-c == 3)
            tqu = _mm256_set_pd(0.0, u[(a*B+b)*C+c+2], u[(a*B+b)*C+c+1],
                                u[(a*B+b)*C+c+0]);
          else
            tqu = _mm256_loadu_pd(&u[(a*B+b)*C+c]);
          for (CeedInt jj=0; jj<JJ; jj++) // unroll
            fmadd(vv[jj], tqu, _mm256_set1_pd(t[(j+jj)*tstride0 + b*tstride1]));
        }
        for (CeedInt jj=0; jj<JJ; jj++)
          _mm256_storeu_pd(&v[(a*J+j+jj)*C+c], vv[jj]);
      }
    }
    // Remainder of rows, all columns: plain scalar contraction for rows
    // Jbreak..J-1 over the same column tail
    for (CeedInt j=Jbreak; j<J; j++)
      for (CeedInt b=0; b<B; b++) {
        CeedScalar tq = t[j*tstride0 + b*tstride1];
        for (CeedInt c=(C/CC)*CC; c<C; c++)
          v[(a*J+j)*C+c] += tq * u[(a*B+b)*C+c];
      }
  }
  return 0;
}
135 
136 //------------------------------------------------------------------------------
137 // Serial Tensor Contract C=1
138 //------------------------------------------------------------------------------
139 static inline int CeedTensorContract_Avx_Single(CeedTensorContract contract,
140     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
141     CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
142     CeedScalar *restrict v, const CeedInt AA, const CeedInt JJ) {
143   CeedInt tstride0 = B, tstride1 = 1;
144   if (tmode == CEED_TRANSPOSE) {
145     tstride0 = 1; tstride1 = J;
146   }
147 
148   // Blocks of 4 rows
149   for (CeedInt a=0; a<(A/AA)*AA; a+=AA) {
150     for (CeedInt j=0; j<(J/JJ)*JJ; j+=JJ) {
151       __m256d vv[AA][JJ/4]; // Output tile to be held in registers
152       for (CeedInt aa=0; aa<AA; aa++)
153         for (CeedInt jj=0; jj<JJ/4; jj++)
154           vv[aa][jj] = _mm256_loadu_pd(&v[(a+aa)*J+j+jj*4]);
155 
156       for (CeedInt b=0; b<B; b++) {
157         for (CeedInt jj=0; jj<JJ/4; jj++) { // unroll
158           __m256d tqv = _mm256_set_pd(t[(j+jj*4+3)*tstride0 + b*tstride1],
159                                       t[(j+jj*4+2)*tstride0 + b*tstride1],
160                                       t[(j+jj*4+1)*tstride0 + b*tstride1],
161                                       t[(j+jj*4+0)*tstride0 + b*tstride1]);
162           for (CeedInt aa=0; aa<AA; aa++) // unroll
163             fmadd(vv[aa][jj], tqv, _mm256_set1_pd(u[(a+aa)*B+b]));
164         }
165       }
166       for (CeedInt aa=0; aa<AA; aa++)
167         for (CeedInt jj=0; jj<JJ/4; jj++)
168           _mm256_storeu_pd(&v[(a+aa)*J+j+jj*4], vv[aa][jj]);
169     }
170   }
171   // Remainder of rows
172   CeedInt a=(A/AA)*AA;
173   for (CeedInt j=0; j<(J/JJ)*JJ; j+=JJ) {
174     __m256d vv[AA][JJ/4]; // Output tile to be held in registers
175     for (CeedInt aa=0; aa<A-a; aa++)
176       for (CeedInt jj=0; jj<JJ/4; jj++)
177         vv[aa][jj] = _mm256_loadu_pd(&v[(a+aa)*J+j+jj*4]);
178 
179     for (CeedInt b=0; b<B; b++) {
180       for (CeedInt jj=0; jj<JJ/4; jj++) { // unroll
181         __m256d tqv = _mm256_set_pd(t[(j+jj*4+3)*tstride0 + b*tstride1],
182                                     t[(j+jj*4+2)*tstride0 + b*tstride1],
183                                     t[(j+jj*4+1)*tstride0 + b*tstride1],
184                                     t[(j+jj*4+0)*tstride0 + b*tstride1]);
185         for (CeedInt aa=0; aa<A-a; aa++) // unroll
186           fmadd(vv[aa][jj], tqv, _mm256_set1_pd(u[(a+aa)*B+b]));
187       }
188     }
189     for (CeedInt aa=0; aa<A-a; aa++)
190       for (CeedInt jj=0; jj<JJ/4; jj++)
191         _mm256_storeu_pd(&v[(a+aa)*J+j+jj*4], vv[aa][jj]);
192   }
193   // Column remainder
194   CeedInt Abreak = A%AA ? (A/AA)*AA : (A/AA-1)*AA;
195   // Blocks of 4 columns
196   for (CeedInt j = (J/JJ)*JJ; j<J; j+=4) {
197     // Blocks of 4 rows
198     for (CeedInt a=0; a<Abreak; a+=AA) {
199       __m256d vv[AA]; // Output tile to be held in registers
200       for (CeedInt aa=0; aa<AA; aa++)
201         vv[aa] = _mm256_loadu_pd(&v[(a+aa)*J+j]);
202 
203       for (CeedInt b=0; b<B; b++) {
204         __m256d tqv;
205         if (J-j == 1)
206           tqv = _mm256_set_pd(0.0, 0.0, 0.0, t[(j+0)*tstride0 + b*tstride1]);
207         else if (J-j == 2)
208           tqv = _mm256_set_pd(0.0, 0.0, t[(j+1)*tstride0 + b*tstride1],
209                               t[(j+0)*tstride0 + b*tstride1]);
210         else if (J-3 == j)
211           tqv = _mm256_set_pd(0.0, t[(j+2)*tstride0 + b*tstride1],
212                               t[(j+1)*tstride0 + b*tstride1],
213                               t[(j+0)*tstride0 + b*tstride1]);
214         else
215           tqv = _mm256_set_pd(t[(j+3)*tstride0 + b*tstride1],
216                               t[(j+2)*tstride0 + b*tstride1],
217                               t[(j+1)*tstride0 + b*tstride1],
218                               t[(j+0)*tstride0 + b*tstride1]);
219         for (CeedInt aa=0; aa<AA; aa++) // unroll
220           fmadd(vv[aa], tqv, _mm256_set1_pd(u[(a+aa)*B+b]));
221       }
222       for (CeedInt aa=0; aa<AA; aa++)
223         _mm256_storeu_pd(&v[(a+aa)*J+j], vv[aa]);
224     }
225   }
226   // Remainder of rows, all columns
227   for (CeedInt b=0; b<B; b++) {
228     for (CeedInt j=(J/JJ)*JJ; j<J; j++) {
229       CeedScalar tq = t[j*tstride0 + b*tstride1];
230       for (CeedInt a=Abreak; a<A; a++)
231         v[a*J+j] += tq * u[a*B+b];
232     }
233   }
234   return 0;
235 }
236 
237 //------------------------------------------------------------------------------
238 // Tensor Contract - Common Sizes
239 //------------------------------------------------------------------------------
240 static int CeedTensorContract_Avx_Blocked_4_8(CeedTensorContract contract,
241     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
242     CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
243     CeedScalar *restrict v) {
244   return CeedTensorContract_Avx_Blocked(contract, A, B, C, J, t, tmode, Add, u,
245                                         v, 4, 8);
246 }
247 static int CeedTensorContract_Avx_Remainder_8_8(CeedTensorContract contract,
248     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
249     CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
250     CeedScalar *restrict v) {
251   return CeedTensorContract_Avx_Remainder(contract, A, B, C, J, t, tmode, Add,
252                                           u, v, 8, 8);
253 }
254 static int CeedTensorContract_Avx_Single_4_8(CeedTensorContract contract,
255     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
256     CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
257     CeedScalar *restrict v) {
258   return CeedTensorContract_Avx_Single(contract, A, B, C, J, t, tmode, Add, u,
259                                        v, 4, 8);
260 }
261 
262 //------------------------------------------------------------------------------
263 // Tensor Contract Apply
264 //------------------------------------------------------------------------------
265 static int CeedTensorContractApply_Avx(CeedTensorContract contract, CeedInt A,
266                                        CeedInt B, CeedInt C, CeedInt J,
267                                        const CeedScalar *restrict t,
268                                        CeedTransposeMode tmode,
269                                        const CeedInt Add,
270                                        const CeedScalar *restrict u,
271                                        CeedScalar *restrict v) {
272   const CeedInt blksize = 8;
273 
274   if (!Add)
275     for (CeedInt q=0; q<A*J*C; q++)
276       v[q] = (CeedScalar) 0.0;
277 
278   if (C == 1) {
279     // Serial C=1 Case
280     CeedTensorContract_Avx_Single_4_8(contract, A, B, C, J, t, tmode, true, u,
281                                       v);
282   } else {
283     // Blocks of 8 columns
284     if (C >= blksize)
285       CeedTensorContract_Avx_Blocked_4_8(contract, A, B, C, J, t, tmode, true,
286                                          u, v);
287     // Remainder of columns
288     if (C % blksize)
289       CeedTensorContract_Avx_Remainder_8_8(contract, A, B, C, J, t, tmode, true,
290                                            u, v);
291   }
292 
293   return 0;
294 }
295 
296 //------------------------------------------------------------------------------
297 // Tensor Contract Destroy
298 //------------------------------------------------------------------------------
// Backend destructor: this backend attaches no per-contract data,
// so there is nothing to free.
static int CeedTensorContractDestroy_Avx(CeedTensorContract contract) {
  return 0;
}
302 
303 //------------------------------------------------------------------------------
304 // Tensor Contract Create
305 //------------------------------------------------------------------------------
// Populate a tensor contraction object with this AVX backend's
// implementations by registering its Apply and Destroy entry points.
//
//   basis    - basis the contraction is created for (unused here)
//   contract - contraction object to populate
//
// Returns 0 on success; nonzero error codes from the Ceed calls are
// handled by the CeedChk macro.
int CeedTensorContractCreate_Avx(CeedBasis basis, CeedTensorContract contract) {
  int ierr;
  Ceed ceed;
  ierr = CeedTensorContractGetCeed(contract, &ceed); CeedChk(ierr);

  // Register the backend entry points on the owning Ceed object
  ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply",
                                CeedTensorContractApply_Avx); CeedChk(ierr);
  ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Destroy",
                                CeedTensorContractDestroy_Avx); CeedChk(ierr);

  return 0;
}
318 //------------------------------------------------------------------------------
319