xref: /libCEED/backends/avx/ceed-avx-tensor.c (revision 0219ea01e2c00bd70a330a05b50ef0218d6ddcb0)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include "ceed-avx.h"
18 
// Blocked Tensor Contract
20 static inline int CeedTensorContract_Avx_Blocked(CeedTensorContract contract,
21     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
22     CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
23     CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) {
24   CeedInt tstride0 = B, tstride1 = 1;
25   if (tmode == CEED_TRANSPOSE) {
26     tstride0 = 1; tstride1 = J;
27   }
28 
29   if (!Add)
30     for (CeedInt q=0; q<A*J*C; q++)
31       v[q] = (CeedScalar) 0.0;
32 
33   for (CeedInt a=0; a<A; a++) {
34     // Blocks of 4 rows
35     for (CeedInt j=0; j<(J/JJ)*JJ; j+=JJ) {
36       for (CeedInt c=0; c<(C/CC)*CC; c+=CC) {
37         __m256d vv[JJ][CC/4]; // Output tile to be held in registers
38         for (CeedInt jj=0; jj<JJ; jj++)
39           for (CeedInt cc=0; cc<CC/4; cc++)
40             vv[jj][cc] = _mm256_loadu_pd(&v[(a*J+j+jj)*C+c+cc*4]);
41 
42         for (CeedInt b=0; b<B; b++) {
43           for (CeedInt jj=0; jj<JJ; jj++) { // unroll
44             __m256d tqv = _mm256_set1_pd(t[(j+jj)*tstride0 + b*tstride1]);
45             for (CeedInt cc=0; cc<CC/4; cc++) // unroll
46               vv[jj][cc] += _mm256_mul_pd(tqv,
47                                           _mm256_loadu_pd(&u[(a*B+b)*C+c+cc*4]));
48           }
49         }
50         for (CeedInt jj=0; jj<JJ; jj++)
51           for (CeedInt cc=0; cc<CC/4; cc++)
52             _mm256_storeu_pd(&v[(a*J+j+jj)*C+c+cc*4], vv[jj][cc]);
53       }
54     }
55     // Remainder of rows
56     CeedInt j=(J/JJ)*JJ;
57     if (j < J) {
58       for (CeedInt c=0; c<(C/CC)*CC; c+=CC) {
59         __m256d vv[JJ][CC/4]; // Output tile to be held in registers
60         for (CeedInt jj=0; jj<J-j; jj++)
61           for (CeedInt cc=0; cc<CC/4; cc++)
62             vv[jj][cc] = _mm256_loadu_pd(&v[(a*J+j+jj)*C+c+cc*4]);
63 
64         for (CeedInt b=0; b<B; b++) {
65           for (CeedInt jj=0; jj<J-j; jj++) { // doesn't unroll
66             __m256d tqv = _mm256_set1_pd(t[(j+jj)*tstride0 + b*tstride1]);
67             for (CeedInt cc=0; cc<CC/4; cc++) // unroll
68               vv[jj][cc] += _mm256_mul_pd(tqv,
69                                           _mm256_loadu_pd(&u[(a*B+b)*C+c+cc*4]));
70           }
71         }
72         for (CeedInt jj=0; jj<J-j; jj++)
73           for (CeedInt cc=0; cc<CC/4; cc++)
74             _mm256_storeu_pd(&v[(a*J+j+jj)*C+c+cc*4], vv[jj][cc]);
75       }
76     }
77   }
78   return 0;
79 }
80 
81 // Serial Tensor Contract Remainder
82 static inline int CeedTensorContract_Avx_Remainder(CeedTensorContract contract,
83     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
84     CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
85     CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) {
86   CeedInt tstride0 = B, tstride1 = 1;
87   if (tmode == CEED_TRANSPOSE) {
88     tstride0 = 1; tstride1 = J;
89   }
90 
91   CeedInt Jbreak = J%JJ ? (J/JJ)*JJ : (J/JJ-1)*JJ;
92   for (CeedInt a=0; a<A; a++) {
93     // Blocks of 4 columns
94     for (CeedInt c = (C/CC)*CC; c<C; c+=4) {
95       // Blocks of 4 rows
96       for (CeedInt j=0; j<Jbreak; j+=JJ) {
97         __m256d vv[JJ]; // Output tile to be held in registers
98         for (CeedInt jj=0; jj<JJ; jj++)
99           vv[jj] = _mm256_loadu_pd(&v[(a*J+j+jj)*C+c]);
100 
101         for (CeedInt b=0; b<B; b++) {
102           __m256d tqu;
103           if (C-c == 1)
104             tqu = _mm256_set_pd(0.0, 0.0, 0.0, u[(a*B+b)*C+c+0]);
105           else if (C-c == 2)
106             tqu = _mm256_set_pd(0.0, 0.0, u[(a*B+b)*C+c+1],
107                                 u[(a*B+b)*C+c+0]);
108           else if (C-c == 3)
109             tqu = _mm256_set_pd(0.0, u[(a*B+b)*C+c+2], u[(a*B+b)*C+c+1],
110                                 u[(a*B+b)*C+c+0]);
111           else
112             tqu = _mm256_loadu_pd(&u[(a*B+b)*C+c]);
113           for (CeedInt jj=0; jj<JJ; jj++) // unroll
114             vv[jj] += _mm256_mul_pd(tqu,
115                                     _mm256_set1_pd(t[(j+jj)*tstride0 + b*tstride1]));
116         }
117         for (CeedInt jj=0; jj<JJ; jj++)
118           _mm256_storeu_pd(&v[(a*J+j+jj)*C+c], vv[jj]);
119       }
120     }
121     // Remainder of rows, all columns
122     for (CeedInt j=Jbreak; j<J; j++)
123       for (CeedInt b=0; b<B; b++) {
124         CeedScalar tq = t[j*tstride0 + b*tstride1];
125         for (CeedInt c=(C/CC)*CC; c<C; c++)
126           v[(a*J+j)*C+c] += tq * u[(a*B+b)*C+c];
127       }
128   }
129   return 0;
130 }
131 
132 // Serial Tensor Contract C=1 Case
133 static inline int CeedTensorContract_Avx_Single(CeedTensorContract contract,
134     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
135     CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
136     CeedScalar *restrict v, const CeedInt AA, const CeedInt JJ) {
137   CeedInt tstride0 = B, tstride1 = 1;
138   if (tmode == CEED_TRANSPOSE) {
139     tstride0 = 1; tstride1 = J;
140   }
141 
142   if (!Add)
143     for (CeedInt q=0; q<A*J*C; q++)
144       v[q] = (CeedScalar) 0.0;
145 
146   // Blocks of 4 rows
147   for (CeedInt a=0; a<(A/AA)*AA; a+=AA) {
148     for (CeedInt j=0; j<(J/JJ)*JJ; j+=JJ) {
149       __m256d vv[AA][JJ/4]; // Output tile to be held in registers
150       for (CeedInt aa=0; aa<AA; aa++)
151         for (CeedInt jj=0; jj<JJ/4; jj++)
152           vv[aa][jj] = _mm256_loadu_pd(&v[(a+aa)*J+j+jj*4]);
153 
154       for (CeedInt b=0; b<B; b++) {
155         for (CeedInt jj=0; jj<JJ/4; jj++) { // unroll
156           __m256d tqv = _mm256_set_pd(t[(j+jj*4+3)*tstride0 + b*tstride1],
157                                       t[(j+jj*4+2)*tstride0 + b*tstride1],
158                                       t[(j+jj*4+1)*tstride0 + b*tstride1],
159                                       t[(j+jj*4+0)*tstride0 + b*tstride1]);
160           for (CeedInt aa=0; aa<AA; aa++) // unroll
161             vv[aa][jj] += _mm256_mul_pd(tqv, _mm256_set1_pd(u[(a+aa)*B+b]));
162         }
163       }
164       for (CeedInt aa=0; aa<AA; aa++)
165         for (CeedInt jj=0; jj<JJ/4; jj++)
166           _mm256_storeu_pd(&v[(a+aa)*J+j+jj*4], vv[aa][jj]);
167     }
168   }
169   // Remainder of rows
170   CeedInt a=(A/AA)*AA;
171   for (CeedInt j=0; j<(J/JJ)*JJ; j+=JJ) {
172     __m256d vv[AA][JJ/4]; // Output tile to be held in registers
173     for (CeedInt aa=0; aa<A-a; aa++)
174       for (CeedInt jj=0; jj<JJ/4; jj++)
175         vv[aa][jj] = _mm256_loadu_pd(&v[(a+aa)*J+j+jj*4]);
176 
177     for (CeedInt b=0; b<B; b++) {
178       for (CeedInt jj=0; jj<JJ/4; jj++) { // unroll
179         __m256d tqv = _mm256_set_pd(t[(j+jj*4+3)*tstride0 + b*tstride1],
180                                     t[(j+jj*4+2)*tstride0 + b*tstride1],
181                                     t[(j+jj*4+1)*tstride0 + b*tstride1],
182                                     t[(j+jj*4+0)*tstride0 + b*tstride1]);
183         for (CeedInt aa=0; aa<A-a; aa++) // unroll
184           vv[aa][jj] += _mm256_mul_pd(tqv, _mm256_set1_pd(u[(a+aa)*B+b]));
185       }
186     }
187     for (CeedInt aa=0; aa<A-a; aa++)
188       for (CeedInt jj=0; jj<JJ/4; jj++)
189         _mm256_storeu_pd(&v[(a+aa)*J+j+jj*4], vv[aa][jj]);
190   }
191   // Column remainder
192   CeedInt Abreak = A%AA ? (A/AA)*AA : (A/AA-1)*AA;
193   // Blocks of 4 columns
194   for (CeedInt j = (J/JJ)*JJ; j<J; j+=4) {
195     // Blocks of 4 rows
196     for (CeedInt a=0; a<Abreak; a+=AA) {
197       __m256d vv[AA]; // Output tile to be held in registers
198       for (CeedInt aa=0; aa<AA; aa++)
199         vv[aa] = _mm256_loadu_pd(&v[(a+aa)*J+j]);
200 
201       for (CeedInt b=0; b<B; b++) {
202         __m256d tqv;
203         if (J-j == 1)
204           tqv = _mm256_set_pd(0.0, 0.0, 0.0, t[(j+0)*tstride0 + b*tstride1]);
205         else if (J-j == 2)
206           tqv = _mm256_set_pd(0.0, 0.0, t[(j+1)*tstride0 + b*tstride1],
207                               t[(j+0)*tstride0 + b*tstride1]);
208         else if (J-3 == j)
209           tqv = _mm256_set_pd(0.0, t[(j+2)*tstride0 + b*tstride1],
210                               t[(j+1)*tstride0 + b*tstride1],
211                               t[(j+0)*tstride0 + b*tstride1]);
212         else
213           tqv = _mm256_set_pd(t[(j+3)*tstride0 + b*tstride1],
214                               t[(j+2)*tstride0 + b*tstride1],
215                               t[(j+1)*tstride0 + b*tstride1],
216                               t[(j+0)*tstride0 + b*tstride1]);
217         for (CeedInt aa=0; aa<AA; aa++) // unroll
218           vv[aa] += _mm256_mul_pd(tqv, _mm256_set1_pd(u[(a+aa)*B+b]));
219       }
220       for (CeedInt aa=0; aa<AA; aa++)
221         _mm256_storeu_pd(&v[(a+aa)*J+j], vv[aa]);
222     }
223   }
224   // Remainder of rows, all columns
225   for (CeedInt b=0; b<B; b++) {
226     for (CeedInt j=(J/JJ)*JJ; j<J; j++) {
227       CeedScalar tq = t[j*tstride0 + b*tstride1];
228       for (CeedInt a=Abreak; a<A; a++)
229         v[a*J+j] += tq * u[a*B+b];
230     }
231   }
232   return 0;
233 }
234 
235 // Specific Variants
236 static int CeedTensorContract_Avx_Blocked_4_8(CeedTensorContract contract,
237     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
238     CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
239     CeedScalar *restrict v) {
240   return CeedTensorContract_Avx_Blocked(contract, A, B, C, J, t, tmode, Add, u,
241                                         v, 4, 8);
242 }
243 static int CeedTensorContract_Avx_Remainder_8_8(CeedTensorContract contract,
244     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
245     CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
246     CeedScalar *restrict v) {
247   return CeedTensorContract_Avx_Remainder(contract, A, B, C, J, t, tmode, Add,
248                                           u, v, 8, 8);
249 }
250 static int CeedTensorContract_Avx_Single_4_8(CeedTensorContract contract,
251     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
252     CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
253     CeedScalar *restrict v) {
254   return CeedTensorContract_Avx_Single(contract, A, B, C, J, t, tmode, Add, u,
255                                        v, 4, 8);
256 }
257 
258 // Switch for Tensor Contract
259 static int CeedTensorContractApply_Avx(CeedTensorContract contract, CeedInt A,
260                                        CeedInt B, CeedInt C, CeedInt J,
261                                        const CeedScalar *restrict t,
262                                        CeedTransposeMode tmode,
263                                        const CeedInt Add,
264                                        const CeedScalar *restrict u,
265                                        CeedScalar *restrict v) {
266   const CeedInt blksize = 8;
267 
268   if (!Add)
269     for (CeedInt q=0; q<A*J*C; q++)
270       v[q] = (CeedScalar) 0.0;
271 
272   if (C == 1) {
273     // Serial C=1 Case
274     CeedTensorContract_Avx_Single_4_8(contract, A, B, C, J, t, tmode, true, u,
275                                       v);
276   } else {
277     // Blocks of 8 columns
278     if (C >= blksize)
279       CeedTensorContract_Avx_Blocked_4_8(contract, A, B, C, J, t, tmode, true,
280                                          u, v);
281     // Remainder of columns
282     if (C % blksize)
283       CeedTensorContract_Avx_Remainder_8_8(contract, A, B, C, J, t, tmode, true,
284                                            u, v);
285   }
286 
287   return 0;
288 }
289 
// Backend destroy hook: this backend attaches no per-contract data, so there
// is nothing to free
static int CeedTensorContractDestroy_Avx(CeedTensorContract contract) {
  return 0;
}
293 
// Creates the AVX tensor-contraction backend for the given basis by
// registering this file's Apply/Destroy implementations on the contract
// object.  Returns 0 on success; errors propagate via CeedChk.
int CeedTensorContractCreate_Avx(CeedBasis basis, CeedTensorContract contract) {
  int ierr;
  Ceed ceed;
  ierr = CeedTensorContractGetCeed(contract, &ceed); CeedChk(ierr);

  // Install the backend entry points in the Ceed function table
  ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply",
                                CeedTensorContractApply_Avx); CeedChk(ierr);
  ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Destroy",
                                CeedTensorContractDestroy_Avx); CeedChk(ierr);

  return 0;
}
306