xref: /libCEED/backends/avx/ceed-avx-tensor.c (revision c907536f492a0d8a515e9b59267362cd88b5caa9)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <string.h>
18 #include <immintrin.h>
19 #include "ceed-avx.h"
20 
21 // Blocked Tensor Contact
22 static inline int CeedTensorContract_Avx_Blocked(CeedTensorContract contract,
23     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
24     CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
25     CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) {
26   CeedInt tstride0 = B, tstride1 = 1;
27   if (tmode == CEED_TRANSPOSE) {
28     tstride0 = 1; tstride1 = J;
29   }
30 
31   if (!Add)
32     for (CeedInt q=0; q<A*J*C; q++)
33       v[q] = (CeedScalar) 0.0;
34 
35   for (CeedInt a=0; a<A; a++) {
36     // Blocks of 4 rows
37     for (CeedInt j=0; j<(J/JJ)*JJ; j+=JJ) {
38       for (CeedInt c=0; c<(C/CC)*CC; c+=CC) {
39         __m256d vv[JJ][CC/4]; // Output tile to be held in registers
40         for (CeedInt jj=0; jj<JJ; jj++)
41           for (CeedInt cc=0; cc<CC/4; cc++)
42             vv[jj][cc] = _mm256_loadu_pd(&v[(a*J+j+jj)*C+c+cc*4]);
43 
44         for (CeedInt b=0; b<B; b++) {
45           for (CeedInt jj=0; jj<JJ; jj++) { // unroll
46             __m256d tqv = _mm256_set1_pd(t[(j+jj)*tstride0 + b*tstride1]);
47             for (CeedInt cc=0; cc<CC/4; cc++) // unroll
48               vv[jj][cc] += _mm256_mul_pd(tqv,
49                                           _mm256_loadu_pd(&u[(a*B+b)*C+c+cc*4]));
50           }
51         }
52         for (CeedInt jj=0; jj<JJ; jj++)
53           for (CeedInt cc=0; cc<CC/4; cc++)
54             _mm256_storeu_pd(&v[(a*J+j+jj)*C+c+cc*4], vv[jj][cc]);
55       }
56     }
57     // Remainder of rows
58     CeedInt j=(J/JJ)*JJ;
59     if (j < J) {
60       for (CeedInt c=0; c<(C/CC)*CC; c+=CC) {
61         __m256d vv[JJ][CC/4]; // Output tile to be held in registers
62         for (CeedInt jj=0; jj<J-j; jj++)
63           for (CeedInt cc=0; cc<CC/4; cc++)
64             vv[jj][cc] = _mm256_loadu_pd(&v[(a*J+j+jj)*C+c+cc*4]);
65 
66         for (CeedInt b=0; b<B; b++) {
67           for (CeedInt jj=0; jj<J-j; jj++) { // doesn't unroll
68             __m256d tqv = _mm256_set1_pd(t[(j+jj)*tstride0 + b*tstride1]);
69             for (CeedInt cc=0; cc<CC/4; cc++) // unroll
70               vv[jj][cc] += _mm256_mul_pd(tqv,
71                                           _mm256_loadu_pd(&u[(a*B+b)*C+c+cc*4]));
72           }
73         }
74         for (CeedInt jj=0; jj<J-j; jj++)
75           for (CeedInt cc=0; cc<CC/4; cc++)
76             _mm256_storeu_pd(&v[(a*J+j+jj)*C+c+cc*4], vv[jj][cc]);
77       }
78     }
79   }
80   return 0;
81 }
82 
83 // Serial Tensor Contract Remainder
84 static inline int CeedTensorContract_Avx_Remainder(CeedTensorContract contract,
85     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
86     CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
87     CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) {
88   CeedInt tstride0 = B, tstride1 = 1;
89   if (tmode == CEED_TRANSPOSE) {
90     tstride0 = 1; tstride1 = J;
91   }
92 
93   CeedInt Jbreak = J%JJ ? (J/JJ)*JJ : (J/JJ-1)*JJ;
94   for (CeedInt a=0; a<A; a++) {
95     // Blocks of 4 columns
96     for (CeedInt c = (C/CC)*CC; c<C; c+=4) {
97       // Blocks of 4 rows
98       for (CeedInt j=0; j<Jbreak; j+=JJ) {
99         __m256d vv[JJ]; // Output tile to be held in registers
100         for (CeedInt jj=0; jj<JJ; jj++)
101           vv[jj] = _mm256_loadu_pd(&v[(a*J+j+jj)*C+c]);
102 
103         for (CeedInt b=0; b<B; b++) {
104           __m256d tqu;
105           if (C-c == 1)
106             tqu = _mm256_set_pd(0.0, 0.0, 0.0, u[(a*B+b)*C+c+0]);
107           else if (C-c == 2)
108             tqu = _mm256_set_pd(0.0, 0.0, u[(a*B+b)*C+c+1],
109                                 u[(a*B+b)*C+c+0]);
110           else if (C-c == 3)
111             tqu = _mm256_set_pd(0.0, u[(a*B+b)*C+c+2], u[(a*B+b)*C+c+1],
112                                 u[(a*B+b)*C+c+0]);
113           else
114             tqu = _mm256_loadu_pd(&u[(a*B+b)*C+c]);
115           for (CeedInt jj=0; jj<JJ; jj++) // unroll
116             vv[jj] += _mm256_mul_pd(tqu,
117                                     _mm256_set1_pd(t[(j+jj)*tstride0 + b*tstride1]));
118         }
119         for (CeedInt jj=0; jj<JJ; jj++)
120           _mm256_storeu_pd(&v[(a*J+j+jj)*C+c], vv[jj]);
121       }
122     }
123     // Remainder of rows, all columns
124     for (CeedInt j=Jbreak; j<J; j++)
125       for (CeedInt b=0; b<B; b++) {
126         CeedScalar tq = t[j*tstride0 + b*tstride1];
127         for (CeedInt c=(C/CC)*CC; c<C; c++)
128           v[(a*J+j)*C+c] += tq * u[(a*B+b)*C+c];
129       }
130   }
131   return 0;
132 }
133 
134 // Serial Tensor Contract C=1 Case
135 static inline int CeedTensorContract_Avx_Single(CeedTensorContract contract,
136     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
137     CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
138     CeedScalar *restrict v, const CeedInt AA, const CeedInt JJ) {
139   CeedInt tstride0 = B, tstride1 = 1;
140   if (tmode == CEED_TRANSPOSE) {
141     tstride0 = 1; tstride1 = J;
142   }
143 
144   if (!Add)
145     for (CeedInt q=0; q<A*J*C; q++)
146       v[q] = (CeedScalar) 0.0;
147 
148   // Blocks of 4 rows
149   for (CeedInt a=0; a<(A/AA)*AA; a+=AA) {
150     for (CeedInt j=0; j<(J/JJ)*JJ; j+=JJ) {
151       __m256d vv[AA][JJ/4]; // Output tile to be held in registers
152       for (CeedInt aa=0; aa<AA; aa++)
153         for (CeedInt jj=0; jj<JJ/4; jj++)
154           vv[aa][jj] = _mm256_loadu_pd(&v[(a+aa)*J+j+jj*4]);
155 
156       for (CeedInt b=0; b<B; b++) {
157         for (CeedInt jj=0; jj<JJ/4; jj++) { // unroll
158           __m256d tqv = _mm256_set_pd(t[(j+jj*4+3)*tstride0 + b*tstride1],
159                                       t[(j+jj*4+2)*tstride0 + b*tstride1],
160                                       t[(j+jj*4+1)*tstride0 + b*tstride1],
161                                       t[(j+jj*4+0)*tstride0 + b*tstride1]);
162           for (CeedInt aa=0; aa<AA; aa++) // unroll
163             vv[aa][jj] += _mm256_mul_pd(tqv, _mm256_set1_pd(u[(a+aa)*B+b]));
164         }
165       }
166       for (CeedInt aa=0; aa<AA; aa++)
167         for (CeedInt jj=0; jj<JJ/4; jj++)
168           _mm256_storeu_pd(&v[(a+aa)*J+j+jj*4], vv[aa][jj]);
169     }
170   }
171   // Remainder of rows
172   CeedInt a=(A/AA)*AA;
173   for (CeedInt j=0; j<(J/JJ)*JJ; j+=JJ) {
174     __m256d vv[AA][JJ/4]; // Output tile to be held in registers
175     for (CeedInt aa=0; aa<A-a; aa++)
176       for (CeedInt jj=0; jj<JJ/4; jj++)
177         vv[aa][jj] = _mm256_loadu_pd(&v[(a+aa)*J+j+jj*4]);
178 
179     for (CeedInt b=0; b<B; b++) {
180       for (CeedInt jj=0; jj<JJ/4; jj++) { // unroll
181         __m256d tqv = _mm256_set_pd(t[(j+jj*4+3)*tstride0 + b*tstride1],
182                                     t[(j+jj*4+2)*tstride0 + b*tstride1],
183                                     t[(j+jj*4+1)*tstride0 + b*tstride1],
184                                     t[(j+jj*4+0)*tstride0 + b*tstride1]);
185         for (CeedInt aa=0; aa<A-a; aa++) // unroll
186           vv[aa][jj] += _mm256_mul_pd(tqv, _mm256_set1_pd(u[(a+aa)*B+b]));
187       }
188     }
189     for (CeedInt aa=0; aa<A-a; aa++)
190       for (CeedInt jj=0; jj<JJ/4; jj++)
191         _mm256_storeu_pd(&v[(a+aa)*J+j+jj*4], vv[aa][jj]);
192   }
193   // Column remainder
194   CeedInt Abreak = A%AA ? (A/AA)*AA : (A/AA-1)*AA;
195   // Blocks of 4 columns
196   for (CeedInt j = (J/JJ)*JJ; j<J; j+=4) {
197     // Blocks of 4 rows
198     for (CeedInt a=0; a<Abreak; a+=AA) {
199       __m256d vv[AA]; // Output tile to be held in registers
200       for (CeedInt aa=0; aa<AA; aa++)
201         vv[aa] = _mm256_loadu_pd(&v[(a+aa)*J+j]);
202 
203       for (CeedInt b=0; b<B; b++) {
204         __m256d tqv;
205         if (J-j == 1)
206           tqv = _mm256_set_pd(0.0, 0.0, 0.0, t[(j+0)*tstride0 + b*tstride1]);
207         else if (J-j == 2)
208           tqv = _mm256_set_pd(0.0, 0.0, t[(j+1)*tstride0 + b*tstride1],
209                               t[(j+0)*tstride0 + b*tstride1]);
210         else if (J-3 == j)
211           tqv = _mm256_set_pd(0.0, t[(j+2)*tstride0 + b*tstride1],
212                               t[(j+1)*tstride0 + b*tstride1],
213                               t[(j+0)*tstride0 + b*tstride1]);
214         else
215           tqv = _mm256_set_pd(t[(j+3)*tstride0 + b*tstride1],
216                               t[(j+2)*tstride0 + b*tstride1],
217                               t[(j+1)*tstride0 + b*tstride1],
218                               t[(j+0)*tstride0 + b*tstride1]);
219         for (CeedInt aa=0; aa<AA; aa++) // unroll
220           vv[aa] += _mm256_mul_pd(tqv, _mm256_set1_pd(u[(a+aa)*B+b]));
221       }
222       for (CeedInt aa=0; aa<AA; aa++)
223         _mm256_storeu_pd(&v[(a+aa)*J+j], vv[aa]);
224     }
225   }
226   // Remainder of rows, all columns
227   for (CeedInt b=0; b<B; b++) {
228     for (CeedInt j=(J/JJ)*JJ; j<J; j++) {
229       CeedScalar tq = t[j*tstride0 + b*tstride1];
230       for (CeedInt a=Abreak; a<A; a++)
231         v[a*J+j] += tq * u[a*B+b];
232     }
233   }
234   return 0;
235 }
236 
237 // Specific Variants
238 static int CeedTensorContract_Avx_Blocked_4_8(CeedTensorContract contract,
239     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
240     CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
241     CeedScalar *restrict v) {
242   return CeedTensorContract_Avx_Blocked(contract, A, B, C, J, t, tmode, Add, u,
243                                         v, 4, 8);
244 }
245 static int CeedTensorContract_Avx_Remainder_8_8(CeedTensorContract contract,
246     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
247     CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
248     CeedScalar *restrict v) {
249   return CeedTensorContract_Avx_Remainder(contract, A, B, C, J, t, tmode, Add,
250                                           u, v, 8, 8);
251 }
252 static int CeedTensorContract_Avx_Single_4_8(CeedTensorContract contract,
253     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
254     CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
255     CeedScalar *restrict v) {
256   return CeedTensorContract_Avx_Single(contract, A, B, C, J, t, tmode, Add, u,
257                                        v, 4, 8);
258 }
259 
260 // Switch for Tensor Contract
261 static int CeedTensorContractApply_Avx(CeedTensorContract contract, CeedInt A,
262                                        CeedInt B, CeedInt C, CeedInt J,
263                                        const CeedScalar *restrict t,
264                                        CeedTransposeMode tmode,
265                                        const CeedInt Add,
266                                        const CeedScalar *restrict u,
267                                        CeedScalar *restrict v) {
268   const CeedInt blksize = 8;
269 
270   if (!Add)
271     for (CeedInt q=0; q<A*J*C; q++)
272       v[q] = (CeedScalar) 0.0;
273 
274   if (C == 1) {
275     // Serial C=1 Case
276     CeedTensorContract_Avx_Single_4_8(contract, A, B, C, J, t, tmode, true, u,
277                                       v);
278   } else {
279     // Blocks of 8 columns
280     if (C >= blksize)
281       CeedTensorContract_Avx_Blocked_4_8(contract, A, B, C, J, t, tmode, true,
282                                          u, v);
283     // Remainder of columns
284     if (C % blksize)
285       CeedTensorContract_Avx_Remainder_8_8(contract, A, B, C, J, t, tmode, true,
286                                            u, v);
287   }
288 
289   return 0;
290 }
291 
292 static int CeedTensorContractDestroy_Avx(CeedTensorContract contract) {
293   return 0;
294 }
295 
296 int CeedTensorContractCreate_Avx(CeedBasis basis, CeedTensorContract contract) {
297   int ierr;
298   Ceed ceed;
299   ierr = CeedTensorContractGetCeed(contract, &ceed); CeedChk(ierr);
300 
301   ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply",
302                                 CeedTensorContractApply_Avx); CeedChk(ierr);
303   ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Destroy",
304                                 CeedTensorContractDestroy_Avx); CeedChk(ierr);
305 
306   return 0;
307 }
308