xref: /libCEED/backends/avx/ceed-avx-tensor.c (revision 7113573b6efd54558bb98b919dff5d6d8ffcff54)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <immintrin.h>
11 #include <stdbool.h>
12 
// Select a 4-lane SIMD register type and intrinsics matching the CeedScalar
// precision.  NOTE(review): _ceed_f64_h is presumably the include guard of the
// double-precision scalar header pulled in via ceed.h - confirm that this is
// how precision is detected here.
#ifdef _ceed_f64_h
// Double precision: 4 doubles in a 256-bit AVX register
#define rtype __m256d
#define loadu _mm256_loadu_pd
#define storeu _mm256_storeu_pd
#define set _mm256_set_pd
#define set1 _mm256_set1_pd
// c += a * b
#ifdef __FMA__
#define fmadd(c, a, b) (c) = _mm256_fmadd_pd((a), (b), (c))
#else
// Fallback without FMA: separate multiply, then vector += (GCC/Clang extension)
#define fmadd(c, a, b) (c) += _mm256_mul_pd((a), (b))
#endif
#else
// Single precision: 4 floats in a 128-bit SSE register (same lane count as above)
#define rtype __m128
#define loadu _mm_loadu_ps
#define storeu _mm_storeu_ps
#define set _mm_set_ps
#define set1 _mm_set1_ps
// c += a * b
#ifdef __FMA__
#define fmadd(c, a, b) (c) = _mm_fmadd_ps((a), (b), (c))
#else
#define fmadd(c, a, b) (c) += _mm_mul_ps((a), (b))
#endif
#endif
38 
39 //------------------------------------------------------------------------------
40 // Blocked Tensor Contract
41 //------------------------------------------------------------------------------
// Contract along dimension B: v[a, j, c] += sum_b t[j, b] * u[a, b, c],
// vectorized across the contiguous index c.  Columns are processed in blocks
// of CC (CC / 4 four-lane vectors) and rows j in blocks of JJ, accumulating a
// JJ x (CC / 4) tile of vector registers.  Only the first (C / CC) * CC
// columns are touched - the caller handles leftover columns with the
// Remainder kernel.  Note: `contract` and `add` are unused here; the caller
// zeroes v beforehand when accumulation is not requested.
static inline int CeedTensorContract_Avx_Blocked(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J,
                                                 const CeedScalar *restrict t, CeedTransposeMode t_mode, const CeedInt add,
                                                 const CeedScalar *restrict u, CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) {
  // Strides for t: default reads t[j, b]; transpose swaps to t[b, j]
  CeedInt t_stride_0 = B, t_stride_1 = 1;
  if (t_mode == CEED_TRANSPOSE) {
    t_stride_0 = 1;
    t_stride_1 = J;
  }

  for (CeedInt a = 0; a < A; a++) {
    // Blocks of 4 rows
    for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) {
      for (CeedInt c = 0; c < (C / CC) * CC; c += CC) {
        rtype vv[JJ][CC / 4];  // Output tile to be held in registers
        for (CeedInt jj = 0; jj < JJ; jj++) {
          for (CeedInt cc = 0; cc < CC / 4; cc++) vv[jj][cc] = loadu(&v[(a * J + j + jj) * C + c + cc * 4]);
        }

        // Accumulate the whole B dimension into the register tile
        for (CeedInt b = 0; b < B; b++) {
          for (CeedInt jj = 0; jj < JJ; jj++) {  // unroll
            rtype tqv = set1(t[(j + jj) * t_stride_0 + b * t_stride_1]);
            for (CeedInt cc = 0; cc < CC / 4; cc++) {  // unroll
              fmadd(vv[jj][cc], tqv, loadu(&u[(a * B + b) * C + c + cc * 4]));
            }
          }
        }
        for (CeedInt jj = 0; jj < JJ; jj++) {
          for (CeedInt cc = 0; cc < CC / 4; cc++) storeu(&v[(a * J + j + jj) * C + c + cc * 4], vv[jj][cc]);
        }
      }
    }
    // Remainder of rows
    // Same tile as above but with a variable row count J - j (< JJ), so the
    // jj loops cannot be unrolled by the compiler.
    CeedInt j = (J / JJ) * JJ;
    if (j < J) {
      for (CeedInt c = 0; c < (C / CC) * CC; c += CC) {
        rtype vv[JJ][CC / 4];  // Output tile to be held in registers
        for (CeedInt jj = 0; jj < J - j; jj++) {
          for (CeedInt cc = 0; cc < CC / 4; cc++) vv[jj][cc] = loadu(&v[(a * J + j + jj) * C + c + cc * 4]);
        }

        for (CeedInt b = 0; b < B; b++) {
          for (CeedInt jj = 0; jj < J - j; jj++) {  // doesn't unroll
            rtype tqv = set1(t[(j + jj) * t_stride_0 + b * t_stride_1]);
            for (CeedInt cc = 0; cc < CC / 4; cc++) {  // unroll
              fmadd(vv[jj][cc], tqv, loadu(&u[(a * B + b) * C + c + cc * 4]));
            }
          }
        }
        for (CeedInt jj = 0; jj < J - j; jj++) {
          for (CeedInt cc = 0; cc < CC / 4; cc++) storeu(&v[(a * J + j + jj) * C + c + cc * 4], vv[jj][cc]);
        }
      }
    }
  }
  return CEED_ERROR_SUCCESS;
}
98 
99 //------------------------------------------------------------------------------
100 // Serial Tensor Contract Remainder
101 //------------------------------------------------------------------------------
// Handle the trailing C % CC columns left over by the blocked kernel, stepping
// 4 columns at a time.  For a final partial step (C - c < 4) the u operand is
// zero-padded via set(), so the 4-wide load/store of v adds zero to - and
// rewrites unchanged - up to 3 entries belonging to the next row of v.
// NOTE(review): J_break excludes the final row block from the vector loop
// (handled by the scalar tail below), apparently so those spilled 4-wide
// accesses always land inside v - confirm.  `contract` and `add` are unused
// here; the caller zeroes v beforehand when accumulation is not requested.
static inline int CeedTensorContract_Avx_Remainder(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J,
                                                   const CeedScalar *restrict t, CeedTransposeMode t_mode, const CeedInt add,
                                                   const CeedScalar *restrict u, CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) {
  // Strides for t: default reads t[j, b]; transpose swaps to t[b, j]
  CeedInt t_stride_0 = B, t_stride_1 = 1;
  if (t_mode == CEED_TRANSPOSE) {
    t_stride_0 = 1;
    t_stride_1 = J;
  }

  // Even when J divides evenly into JJ blocks, hold back the last block for
  // the scalar tail
  CeedInt J_break = J % JJ ? (J / JJ) * JJ : (J / JJ - 1) * JJ;
  for (CeedInt a = 0; a < A; a++) {
    // Blocks of 4 columns
    for (CeedInt c = (C / CC) * CC; c < C; c += 4) {
      // Blocks of 4 rows
      for (CeedInt j = 0; j < J_break; j += JJ) {
        rtype vv[JJ];  // Output tile to be held in registers
        for (CeedInt jj = 0; jj < JJ; jj++) vv[jj] = loadu(&v[(a * J + j + jj) * C + c]);

        for (CeedInt b = 0; b < B; b++) {
          // Zero-pad u lanes past the last column so partial steps add nothing there
          rtype tqu;
          if (C - c == 1) tqu = set(0.0, 0.0, 0.0, u[(a * B + b) * C + c + 0]);
          else if (C - c == 2) tqu = set(0.0, 0.0, u[(a * B + b) * C + c + 1], u[(a * B + b) * C + c + 0]);
          else if (C - c == 3) tqu = set(0.0, u[(a * B + b) * C + c + 2], u[(a * B + b) * C + c + 1], u[(a * B + b) * C + c + 0]);
          else tqu = loadu(&u[(a * B + b) * C + c]);
          for (CeedInt jj = 0; jj < JJ; jj++) {  // unroll
            fmadd(vv[jj], tqu, set1(t[(j + jj) * t_stride_0 + b * t_stride_1]));
          }
        }
        for (CeedInt jj = 0; jj < JJ; jj++) storeu(&v[(a * J + j + jj) * C + c], vv[jj]);
      }
    }
    // Remainder of rows, all columns
    // Scalar tail: rows J_break..J-1 over the leftover columns
    for (CeedInt j = J_break; j < J; j++) {
      for (CeedInt b = 0; b < B; b++) {
        CeedScalar tq = t[j * t_stride_0 + b * t_stride_1];
        for (CeedInt c = (C / CC) * CC; c < C; c++) v[(a * J + j) * C + c] += tq * u[(a * B + b) * C + c];
      }
    }
  }
  return CEED_ERROR_SUCCESS;
}
143 
144 //------------------------------------------------------------------------------
145 // Serial Tensor Contract C=1
146 //------------------------------------------------------------------------------
147 static inline int CeedTensorContract_Avx_Single(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
148                                                 CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v,
149                                                 const CeedInt AA, const CeedInt JJ) {
150   CeedInt t_stride_0 = B, t_stride_1 = 1;
151   if (t_mode == CEED_TRANSPOSE) {
152     t_stride_0 = 1;
153     t_stride_1 = J;
154   }
155 
156   // Blocks of 4 rows
157   for (CeedInt a = 0; a < (A / AA) * AA; a += AA) {
158     for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) {
159       rtype vv[AA][JJ / 4];  // Output tile to be held in registers
160       for (CeedInt aa = 0; aa < AA; aa++) {
161         for (CeedInt jj = 0; jj < JJ / 4; jj++) vv[aa][jj] = loadu(&v[(a + aa) * J + j + jj * 4]);
162       }
163 
164       for (CeedInt b = 0; b < B; b++) {
165         for (CeedInt jj = 0; jj < JJ / 4; jj++) {  // unroll
166           rtype tqv = set(t[(j + jj * 4 + 3) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 2) * t_stride_0 + b * t_stride_1],
167                           t[(j + jj * 4 + 1) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 0) * t_stride_0 + b * t_stride_1]);
168           for (CeedInt aa = 0; aa < AA; aa++) {  // unroll
169             fmadd(vv[aa][jj], tqv, set1(u[(a + aa) * B + b]));
170           }
171         }
172       }
173       for (CeedInt aa = 0; aa < AA; aa++) {
174         for (CeedInt jj = 0; jj < JJ / 4; jj++) storeu(&v[(a + aa) * J + j + jj * 4], vv[aa][jj]);
175       }
176     }
177   }
178   // Remainder of rows
179   CeedInt a = (A / AA) * AA;
180   for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) {
181     rtype vv[AA][JJ / 4];  // Output tile to be held in registers
182     for (CeedInt aa = 0; aa < A - a; aa++) {
183       for (CeedInt jj = 0; jj < JJ / 4; jj++) vv[aa][jj] = loadu(&v[(a + aa) * J + j + jj * 4]);
184     }
185 
186     for (CeedInt b = 0; b < B; b++) {
187       for (CeedInt jj = 0; jj < JJ / 4; jj++) {  // unroll
188         rtype tqv = set(t[(j + jj * 4 + 3) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 2) * t_stride_0 + b * t_stride_1],
189                         t[(j + jj * 4 + 1) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 0) * t_stride_0 + b * t_stride_1]);
190         for (CeedInt aa = 0; aa < A - a; aa++) {  // unroll
191           fmadd(vv[aa][jj], tqv, set1(u[(a + aa) * B + b]));
192         }
193       }
194     }
195     for (CeedInt aa = 0; aa < A - a; aa++) {
196       for (CeedInt jj = 0; jj < JJ / 4; jj++) storeu(&v[(a + aa) * J + j + jj * 4], vv[aa][jj]);
197     }
198   }
199   // Column remainder
200   CeedInt A_break = A % AA ? (A / AA) * AA : (A / AA - 1) * AA;
201   // Blocks of 4 columns
202   for (CeedInt j = (J / JJ) * JJ; j < J; j += 4) {
203     // Blocks of 4 rows
204     for (CeedInt a = 0; a < A_break; a += AA) {
205       rtype vv[AA];  // Output tile to be held in registers
206       for (CeedInt aa = 0; aa < AA; aa++) vv[aa] = loadu(&v[(a + aa) * J + j]);
207 
208       for (CeedInt b = 0; b < B; b++) {
209         rtype tqv;
210         if (J - j == 1) {
211           tqv = set(0.0, 0.0, 0.0, t[(j + 0) * t_stride_0 + b * t_stride_1]);
212         } else if (J - j == 2) {
213           tqv = set(0.0, 0.0, t[(j + 1) * t_stride_0 + b * t_stride_1], t[(j + 0) * t_stride_0 + b * t_stride_1]);
214         } else if (J - 3 == j) {
215           tqv =
216               set(0.0, t[(j + 2) * t_stride_0 + b * t_stride_1], t[(j + 1) * t_stride_0 + b * t_stride_1], t[(j + 0) * t_stride_0 + b * t_stride_1]);
217         } else {
218           tqv = set(t[(j + 3) * t_stride_0 + b * t_stride_1], t[(j + 2) * t_stride_0 + b * t_stride_1], t[(j + 1) * t_stride_0 + b * t_stride_1],
219                     t[(j + 0) * t_stride_0 + b * t_stride_1]);
220         }
221         for (CeedInt aa = 0; aa < AA; aa++) {  // unroll
222           fmadd(vv[aa], tqv, set1(u[(a + aa) * B + b]));
223         }
224       }
225       for (CeedInt aa = 0; aa < AA; aa++) storeu(&v[(a + aa) * J + j], vv[aa]);
226     }
227   }
228   // Remainder of rows, all columns
229   for (CeedInt b = 0; b < B; b++) {
230     for (CeedInt j = (J / JJ) * JJ; j < J; j++) {
231       CeedScalar tq = t[j * t_stride_0 + b * t_stride_1];
232       for (CeedInt a = A_break; a < A; a++) v[a * J + j] += tq * u[a * B + b];
233     }
234   }
235   return CEED_ERROR_SUCCESS;
236 }
237 
238 //------------------------------------------------------------------------------
239 // Tensor Contract - Common Sizes
240 //------------------------------------------------------------------------------
// Blocked kernel specialization: JJ = 4 row block, CC = 8 column block
static int CeedTensorContract_Avx_Blocked_4_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
                                              CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
  return CeedTensorContract_Avx_Blocked(contract, A, B, C, J, t, t_mode, add, u, v, 4, 8);
}
// Remainder kernel specialization: JJ = 8 row block, CC = 8 column block
static int CeedTensorContract_Avx_Remainder_8_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
                                                CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
  return CeedTensorContract_Avx_Remainder(contract, A, B, C, J, t, t_mode, add, u, v, 8, 8);
}
// C = 1 kernel specialization: AA = 4 row block, JJ = 8 column block
static int CeedTensorContract_Avx_Single_4_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
                                             CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
  return CeedTensorContract_Avx_Single(contract, A, B, C, J, t, t_mode, add, u, v, 4, 8);
}
253 
254 //------------------------------------------------------------------------------
255 // Tensor Contract Apply
256 //------------------------------------------------------------------------------
257 static int CeedTensorContractApply_Avx(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
258                                        CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
259   const CeedInt blk_size = 8;
260 
261   if (!add) {
262     for (CeedInt q = 0; q < A * J * C; q++) v[q] = (CeedScalar)0.0;
263   }
264 
265   if (C == 1) {
266     // Serial C=1 Case
267     CeedTensorContract_Avx_Single_4_8(contract, A, B, C, J, t, t_mode, true, u, v);
268   } else {
269     // Blocks of 8 columns
270     if (C >= blk_size) CeedTensorContract_Avx_Blocked_4_8(contract, A, B, C, J, t, t_mode, true, u, v);
271     // Remainder of columns
272     if (C % blk_size) CeedTensorContract_Avx_Remainder_8_8(contract, A, B, C, J, t, t_mode, true, u, v);
273   }
274 
275   return CEED_ERROR_SUCCESS;
276 }
277 
278 //------------------------------------------------------------------------------
279 // Tensor Contract Create
280 //------------------------------------------------------------------------------
// Backend setup: fetch the Ceed context for this contract and register the
// AVX Apply implementation.  `basis` is part of the backend interface but is
// unused here.
int CeedTensorContractCreate_Avx(CeedBasis basis, CeedTensorContract contract) {
  Ceed ceed;
  CeedCallBackend(CeedTensorContractGetCeed(contract, &ceed));

  CeedCallBackend(CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply", CeedTensorContractApply_Avx));

  return CEED_ERROR_SUCCESS;
}
289 
290 //------------------------------------------------------------------------------
291