xref: /libCEED/backends/avx/ceed-avx-tensor.c (revision cb32e2e7f026784d97a57f1901677e9727def907)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include "ceed-avx.h"
18 
// Blocked Tensor Contract
20 static inline int CeedTensorContract_Avx_Blocked(CeedTensorContract contract,
21     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
22     CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
23     CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) {
24   CeedInt tstride0 = B, tstride1 = 1;
25   if (tmode == CEED_TRANSPOSE) {
26     tstride0 = 1; tstride1 = J;
27   }
28 
29   for (CeedInt a=0; a<A; a++) {
30     // Blocks of 4 rows
31     for (CeedInt j=0; j<(J/JJ)*JJ; j+=JJ) {
32       for (CeedInt c=0; c<(C/CC)*CC; c+=CC) {
33         __m256d vv[JJ][CC/4]; // Output tile to be held in registers
34         for (CeedInt jj=0; jj<JJ; jj++)
35           for (CeedInt cc=0; cc<CC/4; cc++)
36             vv[jj][cc] = _mm256_loadu_pd(&v[(a*J+j+jj)*C+c+cc*4]);
37 
38         for (CeedInt b=0; b<B; b++) {
39           for (CeedInt jj=0; jj<JJ; jj++) { // unroll
40             __m256d tqv = _mm256_set1_pd(t[(j+jj)*tstride0 + b*tstride1]);
41             for (CeedInt cc=0; cc<CC/4; cc++) // unroll
42               vv[jj][cc] += _mm256_mul_pd(tqv,
43                                           _mm256_loadu_pd(&u[(a*B+b)*C+c+cc*4]));
44           }
45         }
46         for (CeedInt jj=0; jj<JJ; jj++)
47           for (CeedInt cc=0; cc<CC/4; cc++)
48             _mm256_storeu_pd(&v[(a*J+j+jj)*C+c+cc*4], vv[jj][cc]);
49       }
50     }
51     // Remainder of rows
52     CeedInt j=(J/JJ)*JJ;
53     if (j < J) {
54       for (CeedInt c=0; c<(C/CC)*CC; c+=CC) {
55         __m256d vv[JJ][CC/4]; // Output tile to be held in registers
56         for (CeedInt jj=0; jj<J-j; jj++)
57           for (CeedInt cc=0; cc<CC/4; cc++)
58             vv[jj][cc] = _mm256_loadu_pd(&v[(a*J+j+jj)*C+c+cc*4]);
59 
60         for (CeedInt b=0; b<B; b++) {
61           for (CeedInt jj=0; jj<J-j; jj++) { // doesn't unroll
62             __m256d tqv = _mm256_set1_pd(t[(j+jj)*tstride0 + b*tstride1]);
63             for (CeedInt cc=0; cc<CC/4; cc++) // unroll
64               vv[jj][cc] += _mm256_mul_pd(tqv,
65                                           _mm256_loadu_pd(&u[(a*B+b)*C+c+cc*4]));
66           }
67         }
68         for (CeedInt jj=0; jj<J-j; jj++)
69           for (CeedInt cc=0; cc<CC/4; cc++)
70             _mm256_storeu_pd(&v[(a*J+j+jj)*C+c+cc*4], vv[jj][cc]);
71       }
72     }
73   }
74   return 0;
75 }
76 
77 // Serial Tensor Contract Remainder
78 static inline int CeedTensorContract_Avx_Remainder(CeedTensorContract contract,
79     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
80     CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
81     CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) {
82   CeedInt tstride0 = B, tstride1 = 1;
83   if (tmode == CEED_TRANSPOSE) {
84     tstride0 = 1; tstride1 = J;
85   }
86 
87   CeedInt Jbreak = J%JJ ? (J/JJ)*JJ : (J/JJ-1)*JJ;
88   for (CeedInt a=0; a<A; a++) {
89     // Blocks of 4 columns
90     for (CeedInt c = (C/CC)*CC; c<C; c+=4) {
91       // Blocks of 4 rows
92       for (CeedInt j=0; j<Jbreak; j+=JJ) {
93         __m256d vv[JJ]; // Output tile to be held in registers
94         for (CeedInt jj=0; jj<JJ; jj++)
95           vv[jj] = _mm256_loadu_pd(&v[(a*J+j+jj)*C+c]);
96 
97         for (CeedInt b=0; b<B; b++) {
98           __m256d tqu;
99           if (C-c == 1)
100             tqu = _mm256_set_pd(0.0, 0.0, 0.0, u[(a*B+b)*C+c+0]);
101           else if (C-c == 2)
102             tqu = _mm256_set_pd(0.0, 0.0, u[(a*B+b)*C+c+1],
103                                 u[(a*B+b)*C+c+0]);
104           else if (C-c == 3)
105             tqu = _mm256_set_pd(0.0, u[(a*B+b)*C+c+2], u[(a*B+b)*C+c+1],
106                                 u[(a*B+b)*C+c+0]);
107           else
108             tqu = _mm256_loadu_pd(&u[(a*B+b)*C+c]);
109           for (CeedInt jj=0; jj<JJ; jj++) // unroll
110             vv[jj] += _mm256_mul_pd(tqu,
111                                     _mm256_set1_pd(t[(j+jj)*tstride0 + b*tstride1]));
112         }
113         for (CeedInt jj=0; jj<JJ; jj++)
114           _mm256_storeu_pd(&v[(a*J+j+jj)*C+c], vv[jj]);
115       }
116     }
117     // Remainder of rows, all columns
118     for (CeedInt j=Jbreak; j<J; j++)
119       for (CeedInt b=0; b<B; b++) {
120         CeedScalar tq = t[j*tstride0 + b*tstride1];
121         for (CeedInt c=(C/CC)*CC; c<C; c++)
122           v[(a*J+j)*C+c] += tq * u[(a*B+b)*C+c];
123       }
124   }
125   return 0;
126 }
127 
128 // Serial Tensor Contract C=1 Case
129 static inline int CeedTensorContract_Avx_Single(CeedTensorContract contract,
130     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
131     CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
132     CeedScalar *restrict v, const CeedInt AA, const CeedInt JJ) {
133   CeedInt tstride0 = B, tstride1 = 1;
134   if (tmode == CEED_TRANSPOSE) {
135     tstride0 = 1; tstride1 = J;
136   }
137 
138   // Blocks of 4 rows
139   for (CeedInt a=0; a<(A/AA)*AA; a+=AA) {
140     for (CeedInt j=0; j<(J/JJ)*JJ; j+=JJ) {
141       __m256d vv[AA][JJ/4]; // Output tile to be held in registers
142       for (CeedInt aa=0; aa<AA; aa++)
143         for (CeedInt jj=0; jj<JJ/4; jj++)
144           vv[aa][jj] = _mm256_loadu_pd(&v[(a+aa)*J+j+jj*4]);
145 
146       for (CeedInt b=0; b<B; b++) {
147         for (CeedInt jj=0; jj<JJ/4; jj++) { // unroll
148           __m256d tqv = _mm256_set_pd(t[(j+jj*4+3)*tstride0 + b*tstride1],
149                                       t[(j+jj*4+2)*tstride0 + b*tstride1],
150                                       t[(j+jj*4+1)*tstride0 + b*tstride1],
151                                       t[(j+jj*4+0)*tstride0 + b*tstride1]);
152           for (CeedInt aa=0; aa<AA; aa++) // unroll
153             vv[aa][jj] += _mm256_mul_pd(tqv, _mm256_set1_pd(u[(a+aa)*B+b]));
154         }
155       }
156       for (CeedInt aa=0; aa<AA; aa++)
157         for (CeedInt jj=0; jj<JJ/4; jj++)
158           _mm256_storeu_pd(&v[(a+aa)*J+j+jj*4], vv[aa][jj]);
159     }
160   }
161   // Remainder of rows
162   CeedInt a=(A/AA)*AA;
163   for (CeedInt j=0; j<(J/JJ)*JJ; j+=JJ) {
164     __m256d vv[AA][JJ/4]; // Output tile to be held in registers
165     for (CeedInt aa=0; aa<A-a; aa++)
166       for (CeedInt jj=0; jj<JJ/4; jj++)
167         vv[aa][jj] = _mm256_loadu_pd(&v[(a+aa)*J+j+jj*4]);
168 
169     for (CeedInt b=0; b<B; b++) {
170       for (CeedInt jj=0; jj<JJ/4; jj++) { // unroll
171         __m256d tqv = _mm256_set_pd(t[(j+jj*4+3)*tstride0 + b*tstride1],
172                                     t[(j+jj*4+2)*tstride0 + b*tstride1],
173                                     t[(j+jj*4+1)*tstride0 + b*tstride1],
174                                     t[(j+jj*4+0)*tstride0 + b*tstride1]);
175         for (CeedInt aa=0; aa<A-a; aa++) // unroll
176           vv[aa][jj] += _mm256_mul_pd(tqv, _mm256_set1_pd(u[(a+aa)*B+b]));
177       }
178     }
179     for (CeedInt aa=0; aa<A-a; aa++)
180       for (CeedInt jj=0; jj<JJ/4; jj++)
181         _mm256_storeu_pd(&v[(a+aa)*J+j+jj*4], vv[aa][jj]);
182   }
183   // Column remainder
184   CeedInt Abreak = A%AA ? (A/AA)*AA : (A/AA-1)*AA;
185   // Blocks of 4 columns
186   for (CeedInt j = (J/JJ)*JJ; j<J; j+=4) {
187     // Blocks of 4 rows
188     for (CeedInt a=0; a<Abreak; a+=AA) {
189       __m256d vv[AA]; // Output tile to be held in registers
190       for (CeedInt aa=0; aa<AA; aa++)
191         vv[aa] = _mm256_loadu_pd(&v[(a+aa)*J+j]);
192 
193       for (CeedInt b=0; b<B; b++) {
194         __m256d tqv;
195         if (J-j == 1)
196           tqv = _mm256_set_pd(0.0, 0.0, 0.0, t[(j+0)*tstride0 + b*tstride1]);
197         else if (J-j == 2)
198           tqv = _mm256_set_pd(0.0, 0.0, t[(j+1)*tstride0 + b*tstride1],
199                               t[(j+0)*tstride0 + b*tstride1]);
200         else if (J-3 == j)
201           tqv = _mm256_set_pd(0.0, t[(j+2)*tstride0 + b*tstride1],
202                               t[(j+1)*tstride0 + b*tstride1],
203                               t[(j+0)*tstride0 + b*tstride1]);
204         else
205           tqv = _mm256_set_pd(t[(j+3)*tstride0 + b*tstride1],
206                               t[(j+2)*tstride0 + b*tstride1],
207                               t[(j+1)*tstride0 + b*tstride1],
208                               t[(j+0)*tstride0 + b*tstride1]);
209         for (CeedInt aa=0; aa<AA; aa++) // unroll
210           vv[aa] += _mm256_mul_pd(tqv, _mm256_set1_pd(u[(a+aa)*B+b]));
211       }
212       for (CeedInt aa=0; aa<AA; aa++)
213         _mm256_storeu_pd(&v[(a+aa)*J+j], vv[aa]);
214     }
215   }
216   // Remainder of rows, all columns
217   for (CeedInt b=0; b<B; b++) {
218     for (CeedInt j=(J/JJ)*JJ; j<J; j++) {
219       CeedScalar tq = t[j*tstride0 + b*tstride1];
220       for (CeedInt a=Abreak; a<A; a++)
221         v[a*J+j] += tq * u[a*B+b];
222     }
223   }
224   return 0;
225 }
226 
227 // Specific Variants
228 static int CeedTensorContract_Avx_Blocked_4_8(CeedTensorContract contract,
229     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
230     CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
231     CeedScalar *restrict v) {
232   return CeedTensorContract_Avx_Blocked(contract, A, B, C, J, t, tmode, Add, u,
233                                         v, 4, 8);
234 }
235 static int CeedTensorContract_Avx_Remainder_8_8(CeedTensorContract contract,
236     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
237     CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
238     CeedScalar *restrict v) {
239   return CeedTensorContract_Avx_Remainder(contract, A, B, C, J, t, tmode, Add,
240                                           u, v, 8, 8);
241 }
242 static int CeedTensorContract_Avx_Single_4_8(CeedTensorContract contract,
243     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
244     CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
245     CeedScalar *restrict v) {
246   return CeedTensorContract_Avx_Single(contract, A, B, C, J, t, tmode, Add, u,
247                                        v, 4, 8);
248 }
249 
250 // Switch for Tensor Contract
251 static int CeedTensorContractApply_Avx(CeedTensorContract contract, CeedInt A,
252                                        CeedInt B, CeedInt C, CeedInt J,
253                                        const CeedScalar *restrict t,
254                                        CeedTransposeMode tmode,
255                                        const CeedInt Add,
256                                        const CeedScalar *restrict u,
257                                        CeedScalar *restrict v) {
258   const CeedInt blksize = 8;
259 
260   if (!Add)
261     for (CeedInt q=0; q<A*J*C; q++)
262       v[q] = (CeedScalar) 0.0;
263 
264   if (C == 1) {
265     // Serial C=1 Case
266     CeedTensorContract_Avx_Single_4_8(contract, A, B, C, J, t, tmode, true, u,
267                                       v);
268   } else {
269     // Blocks of 8 columns
270     if (C >= blksize)
271       CeedTensorContract_Avx_Blocked_4_8(contract, A, B, C, J, t, tmode, true,
272                                          u, v);
273     // Remainder of columns
274     if (C % blksize)
275       CeedTensorContract_Avx_Remainder_8_8(contract, A, B, C, J, t, tmode, true,
276                                            u, v);
277   }
278 
279   return 0;
280 }
281 
282 static int CeedTensorContractDestroy_Avx(CeedTensorContract contract) {
283   return 0;
284 }
285 
286 int CeedTensorContractCreate_Avx(CeedBasis basis, CeedTensorContract contract) {
287   int ierr;
288   Ceed ceed;
289   ierr = CeedTensorContractGetCeed(contract, &ceed); CeedChk(ierr);
290 
291   ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply",
292                                 CeedTensorContractApply_Avx); CeedChk(ierr);
293   ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Destroy",
294                                 CeedTensorContractDestroy_Avx); CeedChk(ierr);
295 
296   return 0;
297 }
298