xref: /libCEED/backends/avx/ceed-avx-tensor.c (revision 346c77e6436e93de99b1714e06a264fc70d47960)
1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <immintrin.h>
11 #include <stdbool.h>
12 
#ifdef CEED_SCALAR_IS_FP64
// Double precision: one vector register holds 4 x f64 (AVX, 256-bit)
#define rtype __m256d
#define loadu _mm256_loadu_pd
#define storeu _mm256_storeu_pd
#define set _mm256_set_pd
#define set1 _mm256_set1_pd
// c += a * b
#ifdef __FMA__
#define fmadd(c, a, b) (c) = _mm256_fmadd_pd((a), (b), (c))
#else
// NOTE(review): vector += relies on GCC/Clang vector-operator extensions —
// confirm all supported compilers accept this non-FMA fallback
#define fmadd(c, a, b) (c) += _mm256_mul_pd((a), (b))
#endif
#else
// Single precision: one vector register holds 4 x f32 (SSE, 128-bit), so the
// kernels below see the same 4-wide lane count in either precision
#define rtype __m128
#define loadu _mm_loadu_ps
#define storeu _mm_storeu_ps
#define set _mm_set_ps
#define set1 _mm_set1_ps
// c += a * b
#ifdef __FMA__
#define fmadd(c, a, b) (c) = _mm_fmadd_ps((a), (b), (c))
#else
#define fmadd(c, a, b) (c) += _mm_mul_ps((a), (b))
#endif
#endif
38 
39 //------------------------------------------------------------------------------
40 // Blocked Tensor Contract
41 //------------------------------------------------------------------------------
// Compute the tensor contraction
//   v[a, j, c] += sum_b t[j, b] * u[a, b, c]
// (t is indexed as t[b, j] when t_mode == CEED_TRANSPOSE), vectorized along c
// in 4-wide SIMD lanes and blocked into a JJ-row by CC-column register tile.
//
// Only full CC-wide column blocks are touched here; the caller routes the
// leftover C % CC columns to the Remainder kernel. The caller also zeroes v
// when add is false, so this kernel always accumulates into v (`add` is
// unused, as is `contract`). CC must be a multiple of 4 (the SIMD lane count).
static inline int CeedTensorContract_Avx_Blocked(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J,
                                                 const CeedScalar *restrict t, CeedTransposeMode t_mode, const CeedInt add,
                                                 const CeedScalar *restrict u, CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) {
  // Strides such that t[j, b] == t[j * t_stride_0 + b * t_stride_1]
  CeedInt t_stride_0 = B, t_stride_1 = 1;

  if (t_mode == CEED_TRANSPOSE) {
    t_stride_0 = 1;
    t_stride_1 = J;
  }

  for (CeedInt a = 0; a < A; a++) {
    // Full blocks of JJ rows
    for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) {
      for (CeedInt c = 0; c < (C / CC) * CC; c += CC) {
        rtype vv[JJ][CC / 4];  // Output tile to be held in registers
        // Load the current v tile as the accumulation base
        for (CeedInt jj = 0; jj < JJ; jj++) {
          for (CeedInt cc = 0; cc < CC / 4; cc++) vv[jj][cc] = loadu(&v[(a * J + j + jj) * C + c + cc * 4]);
        }
        // Accumulate every b's contribution into the register tile
        for (CeedInt b = 0; b < B; b++) {
          for (CeedInt jj = 0; jj < JJ; jj++) {  // unroll
            rtype tqv = set1(t[(j + jj) * t_stride_0 + b * t_stride_1]);  // broadcast t[j + jj, b]
            for (CeedInt cc = 0; cc < CC / 4; cc++) {  // unroll
              fmadd(vv[jj][cc], tqv, loadu(&u[(a * B + b) * C + c + cc * 4]));
            }
          }
        }
        // Write the finished tile back to v
        for (CeedInt jj = 0; jj < JJ; jj++) {
          for (CeedInt cc = 0; cc < CC / 4; cc++) storeu(&v[(a * J + j + jj) * C + c + cc * 4], vv[jj][cc]);
        }
      }
    }
    // Remainder of rows: same tile loop, but only J - j (< JJ) rows are live
    const CeedInt j = (J / JJ) * JJ;

    if (j < J) {
      for (CeedInt c = 0; c < (C / CC) * CC; c += CC) {
        rtype vv[JJ][CC / 4];  // Output tile to be held in registers

        for (CeedInt jj = 0; jj < J - j; jj++) {
          for (CeedInt cc = 0; cc < CC / 4; cc++) vv[jj][cc] = loadu(&v[(a * J + j + jj) * C + c + cc * 4]);
        }
        for (CeedInt b = 0; b < B; b++) {
          for (CeedInt jj = 0; jj < J - j; jj++) {  // doesn't unroll
            rtype tqv = set1(t[(j + jj) * t_stride_0 + b * t_stride_1]);

            for (CeedInt cc = 0; cc < CC / 4; cc++) {  // unroll
              fmadd(vv[jj][cc], tqv, loadu(&u[(a * B + b) * C + c + cc * 4]));
            }
          }
        }
        for (CeedInt jj = 0; jj < J - j; jj++) {
          for (CeedInt cc = 0; cc < CC / 4; cc++) storeu(&v[(a * J + j + jj) * C + c + cc * 4], vv[jj][cc]);
        }
      }
    }
  }
  return CEED_ERROR_SUCCESS;
}
100 
101 //------------------------------------------------------------------------------
102 // Serial Tensor Contract Remainder
103 //------------------------------------------------------------------------------
// Handle the leftover C % CC columns that the blocked kernel skipped, for the
// same contraction v[a, j, c] += sum_b t[j, b] * u[a, b, c] (t indexed as
// t[b, j] when t_mode == CEED_TRANSPOSE). Rows are processed in JJ-row blocks
// of 4-wide vectors; any remaining rows are finished with scalar code.
// `contract` and `add` are unused: the caller zeroes v when add is false, so
// this kernel always accumulates into v.
static inline int CeedTensorContract_Avx_Remainder(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J,
                                                   const CeedScalar *restrict t, CeedTransposeMode t_mode, const CeedInt add,
                                                   const CeedScalar *restrict u, CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) {
  // Strides such that t[j, b] == t[j * t_stride_0 + b * t_stride_1]
  CeedInt t_stride_0 = B, t_stride_1 = 1;

  if (t_mode == CEED_TRANSPOSE) {
    t_stride_0 = 1;
    t_stride_1 = J;
  }

  // Last vectorized row: always strictly less than J, so the final row(s) are
  // handled by the scalar tail below. This matters because a 4-wide v
  // load/store on a partial column block spills into the start of the next
  // row (those lanes are preserved unchanged since tqu zeroes them), which is
  // only in-bounds while a row after the current one exists.
  const CeedInt J_break = J % JJ ? (J / JJ) * JJ : (J / JJ - 1) * JJ;

  for (CeedInt a = 0; a < A; a++) {
    // Blocks of 4 columns
    for (CeedInt c = (C / CC) * CC; c < C; c += 4) {
      // Blocks of JJ rows
      for (CeedInt j = 0; j < J_break; j += JJ) {
        rtype vv[JJ];  // Output tile to be held in registers

        for (CeedInt jj = 0; jj < JJ; jj++) vv[jj] = loadu(&v[(a * J + j + jj) * C + c]);
        for (CeedInt b = 0; b < B; b++) {
          rtype tqu;

          // Partial column block: build u lanes explicitly with zeros past
          // column C so they contribute nothing and never read past u
          if (C - c == 1) tqu = set(0.0, 0.0, 0.0, u[(a * B + b) * C + c + 0]);
          else if (C - c == 2) tqu = set(0.0, 0.0, u[(a * B + b) * C + c + 1], u[(a * B + b) * C + c + 0]);
          else if (C - c == 3) tqu = set(0.0, u[(a * B + b) * C + c + 2], u[(a * B + b) * C + c + 1], u[(a * B + b) * C + c + 0]);
          else tqu = loadu(&u[(a * B + b) * C + c]);
          for (CeedInt jj = 0; jj < JJ; jj++) {  // unroll
            fmadd(vv[jj], tqu, set1(t[(j + jj) * t_stride_0 + b * t_stride_1]));
          }
        }
        for (CeedInt jj = 0; jj < JJ; jj++) storeu(&v[(a * J + j + jj) * C + c], vv[jj]);
      }
    }
    // Remainder of rows, scalar, for the same leftover columns
    for (CeedInt j = J_break; j < J; j++) {
      for (CeedInt b = 0; b < B; b++) {
        const CeedScalar tq = t[j * t_stride_0 + b * t_stride_1];

        for (CeedInt c = (C / CC) * CC; c < C; c++) v[(a * J + j) * C + c] += tq * u[(a * B + b) * C + c];
      }
    }
  }
  return CEED_ERROR_SUCCESS;
}
149 
150 //------------------------------------------------------------------------------
151 // Serial Tensor Contract C=1
152 //------------------------------------------------------------------------------
153 static inline int CeedTensorContract_Avx_Single(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
154                                                 CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v,
155                                                 const CeedInt AA, const CeedInt JJ) {
156   CeedInt t_stride_0 = B, t_stride_1 = 1;
157 
158   if (t_mode == CEED_TRANSPOSE) {
159     t_stride_0 = 1;
160     t_stride_1 = J;
161   }
162 
163   // Blocks of 4 rows
164   for (CeedInt a = 0; a < (A / AA) * AA; a += AA) {
165     for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) {
166       rtype vv[AA][JJ / 4];  // Output tile to be held in registers
167 
168       for (CeedInt aa = 0; aa < AA; aa++) {
169         for (CeedInt jj = 0; jj < JJ / 4; jj++) vv[aa][jj] = loadu(&v[(a + aa) * J + j + jj * 4]);
170       }
171       for (CeedInt b = 0; b < B; b++) {
172         for (CeedInt jj = 0; jj < JJ / 4; jj++) {  // unroll
173           rtype tqv = set(t[(j + jj * 4 + 3) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 2) * t_stride_0 + b * t_stride_1],
174                           t[(j + jj * 4 + 1) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 0) * t_stride_0 + b * t_stride_1]);
175 
176           for (CeedInt aa = 0; aa < AA; aa++) {  // unroll
177             fmadd(vv[aa][jj], tqv, set1(u[(a + aa) * B + b]));
178           }
179         }
180       }
181       for (CeedInt aa = 0; aa < AA; aa++) {
182         for (CeedInt jj = 0; jj < JJ / 4; jj++) storeu(&v[(a + aa) * J + j + jj * 4], vv[aa][jj]);
183       }
184     }
185   }
186   // Remainder of rows
187   const CeedInt a = (A / AA) * AA;
188 
189   for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) {
190     rtype vv[AA][JJ / 4];  // Output tile to be held in registers
191 
192     for (CeedInt aa = 0; aa < A - a; aa++) {
193       for (CeedInt jj = 0; jj < JJ / 4; jj++) vv[aa][jj] = loadu(&v[(a + aa) * J + j + jj * 4]);
194     }
195     for (CeedInt b = 0; b < B; b++) {
196       for (CeedInt jj = 0; jj < JJ / 4; jj++) {  // unroll
197         rtype tqv = set(t[(j + jj * 4 + 3) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 2) * t_stride_0 + b * t_stride_1],
198                         t[(j + jj * 4 + 1) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 0) * t_stride_0 + b * t_stride_1]);
199 
200         for (CeedInt aa = 0; aa < A - a; aa++) {  // unroll
201           fmadd(vv[aa][jj], tqv, set1(u[(a + aa) * B + b]));
202         }
203       }
204     }
205     for (CeedInt aa = 0; aa < A - a; aa++) {
206       for (CeedInt jj = 0; jj < JJ / 4; jj++) storeu(&v[(a + aa) * J + j + jj * 4], vv[aa][jj]);
207     }
208   }
209   // Column remainder
210   const CeedInt A_break = A % AA ? (A / AA) * AA : (A / AA - 1) * AA;
211 
212   // Blocks of 4 columns
213   for (CeedInt j = (J / JJ) * JJ; j < J; j += 4) {
214     // Blocks of 4 rows
215     for (CeedInt a = 0; a < A_break; a += AA) {
216       rtype vv[AA];  // Output tile to be held in registers
217 
218       for (CeedInt aa = 0; aa < AA; aa++) vv[aa] = loadu(&v[(a + aa) * J + j]);
219       for (CeedInt b = 0; b < B; b++) {
220         rtype tqv;
221 
222         if (J - j == 1) {
223           tqv = set(0.0, 0.0, 0.0, t[(j + 0) * t_stride_0 + b * t_stride_1]);
224         } else if (J - j == 2) {
225           tqv = set(0.0, 0.0, t[(j + 1) * t_stride_0 + b * t_stride_1], t[(j + 0) * t_stride_0 + b * t_stride_1]);
226         } else if (J - 3 == j) {
227           tqv =
228               set(0.0, t[(j + 2) * t_stride_0 + b * t_stride_1], t[(j + 1) * t_stride_0 + b * t_stride_1], t[(j + 0) * t_stride_0 + b * t_stride_1]);
229         } else {
230           tqv = set(t[(j + 3) * t_stride_0 + b * t_stride_1], t[(j + 2) * t_stride_0 + b * t_stride_1], t[(j + 1) * t_stride_0 + b * t_stride_1],
231                     t[(j + 0) * t_stride_0 + b * t_stride_1]);
232         }
233         for (CeedInt aa = 0; aa < AA; aa++) {  // unroll
234           fmadd(vv[aa], tqv, set1(u[(a + aa) * B + b]));
235         }
236       }
237       for (CeedInt aa = 0; aa < AA; aa++) storeu(&v[(a + aa) * J + j], vv[aa]);
238     }
239   }
240   // Remainder of rows, all columns
241   for (CeedInt b = 0; b < B; b++) {
242     for (CeedInt j = (J / JJ) * JJ; j < J; j++) {
243       const CeedScalar tq = t[j * t_stride_0 + b * t_stride_1];
244 
245       for (CeedInt a = A_break; a < A; a++) v[a * J + j] += tq * u[a * B + b];
246     }
247   }
248   return CEED_ERROR_SUCCESS;
249 }
250 
251 //------------------------------------------------------------------------------
252 // Tensor Contract - Common Sizes
253 //------------------------------------------------------------------------------
254 static int CeedTensorContract_Avx_Blocked_4_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
255                                               CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
256   return CeedTensorContract_Avx_Blocked(contract, A, B, C, J, t, t_mode, add, u, v, 4, 8);
257 }
258 static int CeedTensorContract_Avx_Remainder_8_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
259                                                 CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
260   return CeedTensorContract_Avx_Remainder(contract, A, B, C, J, t, t_mode, add, u, v, 8, 8);
261 }
262 static int CeedTensorContract_Avx_Single_4_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
263                                              CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
264   return CeedTensorContract_Avx_Single(contract, A, B, C, J, t, t_mode, add, u, v, 4, 8);
265 }
266 
267 //------------------------------------------------------------------------------
268 // Tensor Contract Apply
269 //------------------------------------------------------------------------------
270 static int CeedTensorContractApply_Avx(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
271                                        CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
272   const CeedInt blk_size = 8;
273 
274   if (!add) {
275     for (CeedInt q = 0; q < A * J * C; q++) v[q] = (CeedScalar)0.0;
276   }
277 
278   if (C == 1) {
279     // Serial C=1 Case
280     CeedTensorContract_Avx_Single_4_8(contract, A, B, C, J, t, t_mode, true, u, v);
281   } else {
282     // Blocks of 8 columns
283     if (C >= blk_size) CeedTensorContract_Avx_Blocked_4_8(contract, A, B, C, J, t, t_mode, true, u, v);
284     // Remainder of columns
285     if (C % blk_size) CeedTensorContract_Avx_Remainder_8_8(contract, A, B, C, J, t, t_mode, true, u, v);
286   }
287   return CEED_ERROR_SUCCESS;
288 }
289 
290 //------------------------------------------------------------------------------
291 // Tensor Contract Create
292 //------------------------------------------------------------------------------
293 int CeedTensorContractCreate_Avx(CeedTensorContract contract) {
294   CeedCallBackend(CeedSetBackendFunction(CeedTensorContractReturnCeed(contract), "TensorContract", contract, "Apply", CeedTensorContractApply_Avx));
295   return CEED_ERROR_SUCCESS;
296 }
297 
298 //------------------------------------------------------------------------------
299