// xref: /libCEED/backends/avx/ceed-avx-tensor.c (revision b3d368a0e19d4c2fd0dbe42e8214dac7e42dd106)
// Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
// Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
// All Rights reserved. See files LICENSE and NOTICE for details.
//
// This file is part of CEED, a collection of benchmarks, miniapps, software
// libraries and APIs for efficient high-order finite element and spectral
// element discretizations for exascale applications. For more information and
// source code availability see http://github.com/ceed.
//
// The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
// a collaborative effort of two U.S. Department of Energy organizations (Office
// of Science and the National Nuclear Security Administration) responsible for
// the planning and preparation of a capable exascale ecosystem, including
// software, applications, hardware, advanced system engineering and early
// testbed platforms, in support of the nation's exascale computing imperative.

17 #include "ceed-avx.h"
18 
19 //------------------------------------------------------------------------------
20 // Blocked Tensor Contract
21 //------------------------------------------------------------------------------
22 static inline int CeedTensorContract_Avx_Blocked(CeedTensorContract contract,
23     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
24     CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
25     CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) {
26   CeedInt tstride0 = B, tstride1 = 1;
27   if (tmode == CEED_TRANSPOSE) {
28     tstride0 = 1; tstride1 = J;
29   }
30 
31   for (CeedInt a=0; a<A; a++) {
32     // Blocks of 4 rows
33     for (CeedInt j=0; j<(J/JJ)*JJ; j+=JJ) {
34       for (CeedInt c=0; c<(C/CC)*CC; c+=CC) {
35         __m256d vv[JJ][CC/4]; // Output tile to be held in registers
36         for (CeedInt jj=0; jj<JJ; jj++)
37           for (CeedInt cc=0; cc<CC/4; cc++)
38             vv[jj][cc] = _mm256_loadu_pd(&v[(a*J+j+jj)*C+c+cc*4]);
39 
40         for (CeedInt b=0; b<B; b++) {
41           for (CeedInt jj=0; jj<JJ; jj++) { // unroll
42             __m256d tqv = _mm256_set1_pd(t[(j+jj)*tstride0 + b*tstride1]);
43             for (CeedInt cc=0; cc<CC/4; cc++) // unroll
44               vv[jj][cc] += _mm256_mul_pd(tqv,
45                                           _mm256_loadu_pd(&u[(a*B+b)*C+c+cc*4]));
46           }
47         }
48         for (CeedInt jj=0; jj<JJ; jj++)
49           for (CeedInt cc=0; cc<CC/4; cc++)
50             _mm256_storeu_pd(&v[(a*J+j+jj)*C+c+cc*4], vv[jj][cc]);
51       }
52     }
53     // Remainder of rows
54     CeedInt j=(J/JJ)*JJ;
55     if (j < J) {
56       for (CeedInt c=0; c<(C/CC)*CC; c+=CC) {
57         __m256d vv[JJ][CC/4]; // Output tile to be held in registers
58         for (CeedInt jj=0; jj<J-j; jj++)
59           for (CeedInt cc=0; cc<CC/4; cc++)
60             vv[jj][cc] = _mm256_loadu_pd(&v[(a*J+j+jj)*C+c+cc*4]);
61 
62         for (CeedInt b=0; b<B; b++) {
63           for (CeedInt jj=0; jj<J-j; jj++) { // doesn't unroll
64             __m256d tqv = _mm256_set1_pd(t[(j+jj)*tstride0 + b*tstride1]);
65             for (CeedInt cc=0; cc<CC/4; cc++) // unroll
66               vv[jj][cc] += _mm256_mul_pd(tqv,
67                                           _mm256_loadu_pd(&u[(a*B+b)*C+c+cc*4]));
68           }
69         }
70         for (CeedInt jj=0; jj<J-j; jj++)
71           for (CeedInt cc=0; cc<CC/4; cc++)
72             _mm256_storeu_pd(&v[(a*J+j+jj)*C+c+cc*4], vv[jj][cc]);
73       }
74     }
75   }
76   return 0;
77 }
78 
79 //------------------------------------------------------------------------------
80 // Serial Tensor Contract Remainder
81 //------------------------------------------------------------------------------
82 static inline int CeedTensorContract_Avx_Remainder(CeedTensorContract contract,
83     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
84     CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
85     CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) {
86   CeedInt tstride0 = B, tstride1 = 1;
87   if (tmode == CEED_TRANSPOSE) {
88     tstride0 = 1; tstride1 = J;
89   }
90 
91   CeedInt Jbreak = J%JJ ? (J/JJ)*JJ : (J/JJ-1)*JJ;
92   for (CeedInt a=0; a<A; a++) {
93     // Blocks of 4 columns
94     for (CeedInt c = (C/CC)*CC; c<C; c+=4) {
95       // Blocks of 4 rows
96       for (CeedInt j=0; j<Jbreak; j+=JJ) {
97         __m256d vv[JJ]; // Output tile to be held in registers
98         for (CeedInt jj=0; jj<JJ; jj++)
99           vv[jj] = _mm256_loadu_pd(&v[(a*J+j+jj)*C+c]);
100 
101         for (CeedInt b=0; b<B; b++) {
102           __m256d tqu;
103           if (C-c == 1)
104             tqu = _mm256_set_pd(0.0, 0.0, 0.0, u[(a*B+b)*C+c+0]);
105           else if (C-c == 2)
106             tqu = _mm256_set_pd(0.0, 0.0, u[(a*B+b)*C+c+1],
107                                 u[(a*B+b)*C+c+0]);
108           else if (C-c == 3)
109             tqu = _mm256_set_pd(0.0, u[(a*B+b)*C+c+2], u[(a*B+b)*C+c+1],
110                                 u[(a*B+b)*C+c+0]);
111           else
112             tqu = _mm256_loadu_pd(&u[(a*B+b)*C+c]);
113           for (CeedInt jj=0; jj<JJ; jj++) // unroll
114             vv[jj] += _mm256_mul_pd(tqu,
115                                     _mm256_set1_pd(t[(j+jj)*tstride0 + b*tstride1]));
116         }
117         for (CeedInt jj=0; jj<JJ; jj++)
118           _mm256_storeu_pd(&v[(a*J+j+jj)*C+c], vv[jj]);
119       }
120     }
121     // Remainder of rows, all columns
122     for (CeedInt j=Jbreak; j<J; j++)
123       for (CeedInt b=0; b<B; b++) {
124         CeedScalar tq = t[j*tstride0 + b*tstride1];
125         for (CeedInt c=(C/CC)*CC; c<C; c++)
126           v[(a*J+j)*C+c] += tq * u[(a*B+b)*C+c];
127       }
128   }
129   return 0;
130 }
131 
132 //------------------------------------------------------------------------------
133 // Serial Tensor Contract C=1
134 //------------------------------------------------------------------------------
135 static inline int CeedTensorContract_Avx_Single(CeedTensorContract contract,
136     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
137     CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
138     CeedScalar *restrict v, const CeedInt AA, const CeedInt JJ) {
139   CeedInt tstride0 = B, tstride1 = 1;
140   if (tmode == CEED_TRANSPOSE) {
141     tstride0 = 1; tstride1 = J;
142   }
143 
144   // Blocks of 4 rows
145   for (CeedInt a=0; a<(A/AA)*AA; a+=AA) {
146     for (CeedInt j=0; j<(J/JJ)*JJ; j+=JJ) {
147       __m256d vv[AA][JJ/4]; // Output tile to be held in registers
148       for (CeedInt aa=0; aa<AA; aa++)
149         for (CeedInt jj=0; jj<JJ/4; jj++)
150           vv[aa][jj] = _mm256_loadu_pd(&v[(a+aa)*J+j+jj*4]);
151 
152       for (CeedInt b=0; b<B; b++) {
153         for (CeedInt jj=0; jj<JJ/4; jj++) { // unroll
154           __m256d tqv = _mm256_set_pd(t[(j+jj*4+3)*tstride0 + b*tstride1],
155                                       t[(j+jj*4+2)*tstride0 + b*tstride1],
156                                       t[(j+jj*4+1)*tstride0 + b*tstride1],
157                                       t[(j+jj*4+0)*tstride0 + b*tstride1]);
158           for (CeedInt aa=0; aa<AA; aa++) // unroll
159             vv[aa][jj] += _mm256_mul_pd(tqv, _mm256_set1_pd(u[(a+aa)*B+b]));
160         }
161       }
162       for (CeedInt aa=0; aa<AA; aa++)
163         for (CeedInt jj=0; jj<JJ/4; jj++)
164           _mm256_storeu_pd(&v[(a+aa)*J+j+jj*4], vv[aa][jj]);
165     }
166   }
167   // Remainder of rows
168   CeedInt a=(A/AA)*AA;
169   for (CeedInt j=0; j<(J/JJ)*JJ; j+=JJ) {
170     __m256d vv[AA][JJ/4]; // Output tile to be held in registers
171     for (CeedInt aa=0; aa<A-a; aa++)
172       for (CeedInt jj=0; jj<JJ/4; jj++)
173         vv[aa][jj] = _mm256_loadu_pd(&v[(a+aa)*J+j+jj*4]);
174 
175     for (CeedInt b=0; b<B; b++) {
176       for (CeedInt jj=0; jj<JJ/4; jj++) { // unroll
177         __m256d tqv = _mm256_set_pd(t[(j+jj*4+3)*tstride0 + b*tstride1],
178                                     t[(j+jj*4+2)*tstride0 + b*tstride1],
179                                     t[(j+jj*4+1)*tstride0 + b*tstride1],
180                                     t[(j+jj*4+0)*tstride0 + b*tstride1]);
181         for (CeedInt aa=0; aa<A-a; aa++) // unroll
182           vv[aa][jj] += _mm256_mul_pd(tqv, _mm256_set1_pd(u[(a+aa)*B+b]));
183       }
184     }
185     for (CeedInt aa=0; aa<A-a; aa++)
186       for (CeedInt jj=0; jj<JJ/4; jj++)
187         _mm256_storeu_pd(&v[(a+aa)*J+j+jj*4], vv[aa][jj]);
188   }
189   // Column remainder
190   CeedInt Abreak = A%AA ? (A/AA)*AA : (A/AA-1)*AA;
191   // Blocks of 4 columns
192   for (CeedInt j = (J/JJ)*JJ; j<J; j+=4) {
193     // Blocks of 4 rows
194     for (CeedInt a=0; a<Abreak; a+=AA) {
195       __m256d vv[AA]; // Output tile to be held in registers
196       for (CeedInt aa=0; aa<AA; aa++)
197         vv[aa] = _mm256_loadu_pd(&v[(a+aa)*J+j]);
198 
199       for (CeedInt b=0; b<B; b++) {
200         __m256d tqv;
201         if (J-j == 1)
202           tqv = _mm256_set_pd(0.0, 0.0, 0.0, t[(j+0)*tstride0 + b*tstride1]);
203         else if (J-j == 2)
204           tqv = _mm256_set_pd(0.0, 0.0, t[(j+1)*tstride0 + b*tstride1],
205                               t[(j+0)*tstride0 + b*tstride1]);
206         else if (J-3 == j)
207           tqv = _mm256_set_pd(0.0, t[(j+2)*tstride0 + b*tstride1],
208                               t[(j+1)*tstride0 + b*tstride1],
209                               t[(j+0)*tstride0 + b*tstride1]);
210         else
211           tqv = _mm256_set_pd(t[(j+3)*tstride0 + b*tstride1],
212                               t[(j+2)*tstride0 + b*tstride1],
213                               t[(j+1)*tstride0 + b*tstride1],
214                               t[(j+0)*tstride0 + b*tstride1]);
215         for (CeedInt aa=0; aa<AA; aa++) // unroll
216           vv[aa] += _mm256_mul_pd(tqv, _mm256_set1_pd(u[(a+aa)*B+b]));
217       }
218       for (CeedInt aa=0; aa<AA; aa++)
219         _mm256_storeu_pd(&v[(a+aa)*J+j], vv[aa]);
220     }
221   }
222   // Remainder of rows, all columns
223   for (CeedInt b=0; b<B; b++) {
224     for (CeedInt j=(J/JJ)*JJ; j<J; j++) {
225       CeedScalar tq = t[j*tstride0 + b*tstride1];
226       for (CeedInt a=Abreak; a<A; a++)
227         v[a*J+j] += tq * u[a*B+b];
228     }
229   }
230   return 0;
231 }
232 
233 //------------------------------------------------------------------------------
234 // Tensor Contract - Common Sizes
235 //------------------------------------------------------------------------------
236 static int CeedTensorContract_Avx_Blocked_4_8(CeedTensorContract contract,
237     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
238     CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
239     CeedScalar *restrict v) {
240   return CeedTensorContract_Avx_Blocked(contract, A, B, C, J, t, tmode, Add, u,
241                                         v, 4, 8);
242 }
243 static int CeedTensorContract_Avx_Remainder_8_8(CeedTensorContract contract,
244     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
245     CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
246     CeedScalar *restrict v) {
247   return CeedTensorContract_Avx_Remainder(contract, A, B, C, J, t, tmode, Add,
248                                           u, v, 8, 8);
249 }
250 static int CeedTensorContract_Avx_Single_4_8(CeedTensorContract contract,
251     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
252     CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
253     CeedScalar *restrict v) {
254   return CeedTensorContract_Avx_Single(contract, A, B, C, J, t, tmode, Add, u,
255                                        v, 4, 8);
256 }
257 
258 //------------------------------------------------------------------------------
259 // Tensor Contract Apply
260 //------------------------------------------------------------------------------
261 static int CeedTensorContractApply_Avx(CeedTensorContract contract, CeedInt A,
262                                        CeedInt B, CeedInt C, CeedInt J,
263                                        const CeedScalar *restrict t,
264                                        CeedTransposeMode tmode,
265                                        const CeedInt Add,
266                                        const CeedScalar *restrict u,
267                                        CeedScalar *restrict v) {
268   const CeedInt blksize = 8;
269 
270   if (!Add)
271     for (CeedInt q=0; q<A*J*C; q++)
272       v[q] = (CeedScalar) 0.0;
273 
274   if (C == 1) {
275     // Serial C=1 Case
276     CeedTensorContract_Avx_Single_4_8(contract, A, B, C, J, t, tmode, true, u,
277                                       v);
278   } else {
279     // Blocks of 8 columns
280     if (C >= blksize)
281       CeedTensorContract_Avx_Blocked_4_8(contract, A, B, C, J, t, tmode, true,
282                                          u, v);
283     // Remainder of columns
284     if (C % blksize)
285       CeedTensorContract_Avx_Remainder_8_8(contract, A, B, C, J, t, tmode, true,
286                                            u, v);
287   }
288 
289   return 0;
290 }
291 
//------------------------------------------------------------------------------
// Tensor Contract Destroy
//------------------------------------------------------------------------------
// No-op: this backend attaches no per-contract data, so there is nothing to
// free.  Registered as the "Destroy" backend function so the core library
// always has a valid callback to invoke.
static int CeedTensorContractDestroy_Avx(CeedTensorContract contract) {
  return 0;
}
298 
299 //------------------------------------------------------------------------------
300 // Tensor Contract Create
301 //------------------------------------------------------------------------------
302 int CeedTensorContractCreate_Avx(CeedBasis basis, CeedTensorContract contract) {
303   int ierr;
304   Ceed ceed;
305   ierr = CeedTensorContractGetCeed(contract, &ceed); CeedChk(ierr);
306 
307   ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply",
308                                 CeedTensorContractApply_Avx); CeedChk(ierr);
309   ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Destroy",
310                                 CeedTensorContractDestroy_Avx); CeedChk(ierr);
311 
312   return 0;
313 }
314 //------------------------------------------------------------------------------
315