xref: /libCEED/rust/libceed-sys/c-src/backends/avx/ceed-avx-tensor.c (revision 94b7b29b41ad8a17add4c577886859ef16f89dec)
// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED:  http://github.com/ceed
7c8a55531SSebastian Grimberg 
8c8a55531SSebastian Grimberg #include <ceed.h>
9c8a55531SSebastian Grimberg #include <ceed/backend.h>
10c8a55531SSebastian Grimberg #include <immintrin.h>
11c8a55531SSebastian Grimberg #include <stdbool.h>
12c8a55531SSebastian Grimberg 
// SIMD abstraction used by every kernel in this file. Both variants expose
// 4-lane vectors, which is why the kernels step through memory in units of 4:
//   - CEED_F64_H defined:     4 x double in a 256-bit register (__m256d)
//   - CEED_F64_H not defined: 4 x float  in a 128-bit register (__m128)
// NOTE(review): CEED_F64_H is presumably the include guard of the
// double-precision scalar header pulled in through <ceed.h> -- confirm.
//
//   rtype   vector type
//   loadu   unaligned load of 4 scalars
//   storeu  unaligned store of 4 scalars
//   set     build a vector from 4 scalars (last argument goes to lane 0)
//   set1    broadcast one scalar to all 4 lanes
//   fmadd   c += a * b, fused when the compiler advertises __FMA__
//
// The non-FMA fallback is spelled with explicit add/mul intrinsics rather than
// the `+=` operator: arithmetic operators on vector types are a compiler
// extension, whereas the intrinsics are available wherever <immintrin.h> is.
#ifdef CEED_F64_H
#define rtype __m256d
#define loadu _mm256_loadu_pd
#define storeu _mm256_storeu_pd
#define set _mm256_set_pd
#define set1 _mm256_set1_pd
// c += a * b
#ifdef __FMA__
#define fmadd(c, a, b) (c) = _mm256_fmadd_pd((a), (b), (c))
#else
#define fmadd(c, a, b) (c) = _mm256_add_pd((c), _mm256_mul_pd((a), (b)))
#endif
#else
#define rtype __m128
#define loadu _mm_loadu_ps
#define storeu _mm_storeu_ps
#define set _mm_set_ps
#define set1 _mm_set1_ps
// c += a * b
#ifdef __FMA__
#define fmadd(c, a, b) (c) = _mm_fmadd_ps((a), (b), (c))
#else
#define fmadd(c, a, b) (c) = _mm_add_ps((c), _mm_mul_ps((a), (b)))
#endif
#endif
38c8a55531SSebastian Grimberg 
39c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
40c8a55531SSebastian Grimberg // Blocked Tensor Contract
41c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
42c8a55531SSebastian Grimberg static inline int CeedTensorContract_Avx_Blocked(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J,
43c8a55531SSebastian Grimberg                                                  const CeedScalar *restrict t, CeedTransposeMode t_mode, const CeedInt add,
44c8a55531SSebastian Grimberg                                                  const CeedScalar *restrict u, CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) {
45c8a55531SSebastian Grimberg   CeedInt t_stride_0 = B, t_stride_1 = 1;
46c8a55531SSebastian Grimberg   if (t_mode == CEED_TRANSPOSE) {
47c8a55531SSebastian Grimberg     t_stride_0 = 1;
48c8a55531SSebastian Grimberg     t_stride_1 = J;
49c8a55531SSebastian Grimberg   }
50c8a55531SSebastian Grimberg 
51c8a55531SSebastian Grimberg   for (CeedInt a = 0; a < A; a++) {
52c8a55531SSebastian Grimberg     // Blocks of 4 rows
53c8a55531SSebastian Grimberg     for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) {
54c8a55531SSebastian Grimberg       for (CeedInt c = 0; c < (C / CC) * CC; c += CC) {
55c8a55531SSebastian Grimberg         rtype vv[JJ][CC / 4];  // Output tile to be held in registers
56c8a55531SSebastian Grimberg         for (CeedInt jj = 0; jj < JJ; jj++) {
57c8a55531SSebastian Grimberg           for (CeedInt cc = 0; cc < CC / 4; cc++) vv[jj][cc] = loadu(&v[(a * J + j + jj) * C + c + cc * 4]);
58c8a55531SSebastian Grimberg         }
59c8a55531SSebastian Grimberg 
60c8a55531SSebastian Grimberg         for (CeedInt b = 0; b < B; b++) {
61c8a55531SSebastian Grimberg           for (CeedInt jj = 0; jj < JJ; jj++) {  // unroll
62c8a55531SSebastian Grimberg             rtype tqv = set1(t[(j + jj) * t_stride_0 + b * t_stride_1]);
63c8a55531SSebastian Grimberg             for (CeedInt cc = 0; cc < CC / 4; cc++) {  // unroll
64c8a55531SSebastian Grimberg               fmadd(vv[jj][cc], tqv, loadu(&u[(a * B + b) * C + c + cc * 4]));
65c8a55531SSebastian Grimberg             }
66c8a55531SSebastian Grimberg           }
67c8a55531SSebastian Grimberg         }
68c8a55531SSebastian Grimberg         for (CeedInt jj = 0; jj < JJ; jj++) {
69c8a55531SSebastian Grimberg           for (CeedInt cc = 0; cc < CC / 4; cc++) storeu(&v[(a * J + j + jj) * C + c + cc * 4], vv[jj][cc]);
70c8a55531SSebastian Grimberg         }
71c8a55531SSebastian Grimberg       }
72c8a55531SSebastian Grimberg     }
73c8a55531SSebastian Grimberg     // Remainder of rows
74c8a55531SSebastian Grimberg     CeedInt j = (J / JJ) * JJ;
75c8a55531SSebastian Grimberg     if (j < J) {
76c8a55531SSebastian Grimberg       for (CeedInt c = 0; c < (C / CC) * CC; c += CC) {
77c8a55531SSebastian Grimberg         rtype vv[JJ][CC / 4];  // Output tile to be held in registers
78c8a55531SSebastian Grimberg         for (CeedInt jj = 0; jj < J - j; jj++) {
79c8a55531SSebastian Grimberg           for (CeedInt cc = 0; cc < CC / 4; cc++) vv[jj][cc] = loadu(&v[(a * J + j + jj) * C + c + cc * 4]);
80c8a55531SSebastian Grimberg         }
81c8a55531SSebastian Grimberg 
82c8a55531SSebastian Grimberg         for (CeedInt b = 0; b < B; b++) {
83c8a55531SSebastian Grimberg           for (CeedInt jj = 0; jj < J - j; jj++) {  // doesn't unroll
84c8a55531SSebastian Grimberg             rtype tqv = set1(t[(j + jj) * t_stride_0 + b * t_stride_1]);
85c8a55531SSebastian Grimberg             for (CeedInt cc = 0; cc < CC / 4; cc++) {  // unroll
86c8a55531SSebastian Grimberg               fmadd(vv[jj][cc], tqv, loadu(&u[(a * B + b) * C + c + cc * 4]));
87c8a55531SSebastian Grimberg             }
88c8a55531SSebastian Grimberg           }
89c8a55531SSebastian Grimberg         }
90c8a55531SSebastian Grimberg         for (CeedInt jj = 0; jj < J - j; jj++) {
91c8a55531SSebastian Grimberg           for (CeedInt cc = 0; cc < CC / 4; cc++) storeu(&v[(a * J + j + jj) * C + c + cc * 4], vv[jj][cc]);
92c8a55531SSebastian Grimberg         }
93c8a55531SSebastian Grimberg       }
94c8a55531SSebastian Grimberg     }
95c8a55531SSebastian Grimberg   }
96c8a55531SSebastian Grimberg   return CEED_ERROR_SUCCESS;
97c8a55531SSebastian Grimberg }
98c8a55531SSebastian Grimberg 
99c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
100c8a55531SSebastian Grimberg // Serial Tensor Contract Remainder
101c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
102c8a55531SSebastian Grimberg static inline int CeedTensorContract_Avx_Remainder(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J,
103c8a55531SSebastian Grimberg                                                    const CeedScalar *restrict t, CeedTransposeMode t_mode, const CeedInt add,
104c8a55531SSebastian Grimberg                                                    const CeedScalar *restrict u, CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) {
105c8a55531SSebastian Grimberg   CeedInt t_stride_0 = B, t_stride_1 = 1;
106c8a55531SSebastian Grimberg   if (t_mode == CEED_TRANSPOSE) {
107c8a55531SSebastian Grimberg     t_stride_0 = 1;
108c8a55531SSebastian Grimberg     t_stride_1 = J;
109c8a55531SSebastian Grimberg   }
110c8a55531SSebastian Grimberg 
111c8a55531SSebastian Grimberg   CeedInt J_break = J % JJ ? (J / JJ) * JJ : (J / JJ - 1) * JJ;
112c8a55531SSebastian Grimberg   for (CeedInt a = 0; a < A; a++) {
113c8a55531SSebastian Grimberg     // Blocks of 4 columns
114c8a55531SSebastian Grimberg     for (CeedInt c = (C / CC) * CC; c < C; c += 4) {
115c8a55531SSebastian Grimberg       // Blocks of 4 rows
116c8a55531SSebastian Grimberg       for (CeedInt j = 0; j < J_break; j += JJ) {
117c8a55531SSebastian Grimberg         rtype vv[JJ];  // Output tile to be held in registers
118c8a55531SSebastian Grimberg         for (CeedInt jj = 0; jj < JJ; jj++) vv[jj] = loadu(&v[(a * J + j + jj) * C + c]);
119c8a55531SSebastian Grimberg 
120c8a55531SSebastian Grimberg         for (CeedInt b = 0; b < B; b++) {
121c8a55531SSebastian Grimberg           rtype tqu;
122c8a55531SSebastian Grimberg           if (C - c == 1) tqu = set(0.0, 0.0, 0.0, u[(a * B + b) * C + c + 0]);
123c8a55531SSebastian Grimberg           else if (C - c == 2) tqu = set(0.0, 0.0, u[(a * B + b) * C + c + 1], u[(a * B + b) * C + c + 0]);
124c8a55531SSebastian Grimberg           else if (C - c == 3) tqu = set(0.0, u[(a * B + b) * C + c + 2], u[(a * B + b) * C + c + 1], u[(a * B + b) * C + c + 0]);
125c8a55531SSebastian Grimberg           else tqu = loadu(&u[(a * B + b) * C + c]);
126c8a55531SSebastian Grimberg           for (CeedInt jj = 0; jj < JJ; jj++) {  // unroll
127c8a55531SSebastian Grimberg             fmadd(vv[jj], tqu, set1(t[(j + jj) * t_stride_0 + b * t_stride_1]));
128c8a55531SSebastian Grimberg           }
129c8a55531SSebastian Grimberg         }
130c8a55531SSebastian Grimberg         for (CeedInt jj = 0; jj < JJ; jj++) storeu(&v[(a * J + j + jj) * C + c], vv[jj]);
131c8a55531SSebastian Grimberg       }
132c8a55531SSebastian Grimberg     }
133c8a55531SSebastian Grimberg     // Remainder of rows, all columns
134c8a55531SSebastian Grimberg     for (CeedInt j = J_break; j < J; j++) {
135c8a55531SSebastian Grimberg       for (CeedInt b = 0; b < B; b++) {
136c8a55531SSebastian Grimberg         CeedScalar tq = t[j * t_stride_0 + b * t_stride_1];
137c8a55531SSebastian Grimberg         for (CeedInt c = (C / CC) * CC; c < C; c++) v[(a * J + j) * C + c] += tq * u[(a * B + b) * C + c];
138c8a55531SSebastian Grimberg       }
139c8a55531SSebastian Grimberg     }
140c8a55531SSebastian Grimberg   }
141c8a55531SSebastian Grimberg   return CEED_ERROR_SUCCESS;
142c8a55531SSebastian Grimberg }
143c8a55531SSebastian Grimberg 
144c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
145c8a55531SSebastian Grimberg // Serial Tensor Contract C=1
146c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
147c8a55531SSebastian Grimberg static inline int CeedTensorContract_Avx_Single(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
148c8a55531SSebastian Grimberg                                                 CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v,
149c8a55531SSebastian Grimberg                                                 const CeedInt AA, const CeedInt JJ) {
150c8a55531SSebastian Grimberg   CeedInt t_stride_0 = B, t_stride_1 = 1;
151c8a55531SSebastian Grimberg   if (t_mode == CEED_TRANSPOSE) {
152c8a55531SSebastian Grimberg     t_stride_0 = 1;
153c8a55531SSebastian Grimberg     t_stride_1 = J;
154c8a55531SSebastian Grimberg   }
155c8a55531SSebastian Grimberg 
156c8a55531SSebastian Grimberg   // Blocks of 4 rows
157c8a55531SSebastian Grimberg   for (CeedInt a = 0; a < (A / AA) * AA; a += AA) {
158c8a55531SSebastian Grimberg     for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) {
159c8a55531SSebastian Grimberg       rtype vv[AA][JJ / 4];  // Output tile to be held in registers
160c8a55531SSebastian Grimberg       for (CeedInt aa = 0; aa < AA; aa++) {
161c8a55531SSebastian Grimberg         for (CeedInt jj = 0; jj < JJ / 4; jj++) vv[aa][jj] = loadu(&v[(a + aa) * J + j + jj * 4]);
162c8a55531SSebastian Grimberg       }
163c8a55531SSebastian Grimberg 
164c8a55531SSebastian Grimberg       for (CeedInt b = 0; b < B; b++) {
165c8a55531SSebastian Grimberg         for (CeedInt jj = 0; jj < JJ / 4; jj++) {  // unroll
166c8a55531SSebastian Grimberg           rtype tqv = set(t[(j + jj * 4 + 3) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 2) * t_stride_0 + b * t_stride_1],
167c8a55531SSebastian Grimberg                           t[(j + jj * 4 + 1) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 0) * t_stride_0 + b * t_stride_1]);
168c8a55531SSebastian Grimberg           for (CeedInt aa = 0; aa < AA; aa++) {  // unroll
169c8a55531SSebastian Grimberg             fmadd(vv[aa][jj], tqv, set1(u[(a + aa) * B + b]));
170c8a55531SSebastian Grimberg           }
171c8a55531SSebastian Grimberg         }
172c8a55531SSebastian Grimberg       }
173c8a55531SSebastian Grimberg       for (CeedInt aa = 0; aa < AA; aa++) {
174c8a55531SSebastian Grimberg         for (CeedInt jj = 0; jj < JJ / 4; jj++) storeu(&v[(a + aa) * J + j + jj * 4], vv[aa][jj]);
175c8a55531SSebastian Grimberg       }
176c8a55531SSebastian Grimberg     }
177c8a55531SSebastian Grimberg   }
178c8a55531SSebastian Grimberg   // Remainder of rows
179c8a55531SSebastian Grimberg   CeedInt a = (A / AA) * AA;
180c8a55531SSebastian Grimberg   for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) {
181c8a55531SSebastian Grimberg     rtype vv[AA][JJ / 4];  // Output tile to be held in registers
182c8a55531SSebastian Grimberg     for (CeedInt aa = 0; aa < A - a; aa++) {
183c8a55531SSebastian Grimberg       for (CeedInt jj = 0; jj < JJ / 4; jj++) vv[aa][jj] = loadu(&v[(a + aa) * J + j + jj * 4]);
184c8a55531SSebastian Grimberg     }
185c8a55531SSebastian Grimberg 
186c8a55531SSebastian Grimberg     for (CeedInt b = 0; b < B; b++) {
187c8a55531SSebastian Grimberg       for (CeedInt jj = 0; jj < JJ / 4; jj++) {  // unroll
188c8a55531SSebastian Grimberg         rtype tqv = set(t[(j + jj * 4 + 3) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 2) * t_stride_0 + b * t_stride_1],
189c8a55531SSebastian Grimberg                         t[(j + jj * 4 + 1) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 0) * t_stride_0 + b * t_stride_1]);
190c8a55531SSebastian Grimberg         for (CeedInt aa = 0; aa < A - a; aa++) {  // unroll
191c8a55531SSebastian Grimberg           fmadd(vv[aa][jj], tqv, set1(u[(a + aa) * B + b]));
192c8a55531SSebastian Grimberg         }
193c8a55531SSebastian Grimberg       }
194c8a55531SSebastian Grimberg     }
195c8a55531SSebastian Grimberg     for (CeedInt aa = 0; aa < A - a; aa++) {
196c8a55531SSebastian Grimberg       for (CeedInt jj = 0; jj < JJ / 4; jj++) storeu(&v[(a + aa) * J + j + jj * 4], vv[aa][jj]);
197c8a55531SSebastian Grimberg     }
198c8a55531SSebastian Grimberg   }
199c8a55531SSebastian Grimberg   // Column remainder
200c8a55531SSebastian Grimberg   CeedInt A_break = A % AA ? (A / AA) * AA : (A / AA - 1) * AA;
201c8a55531SSebastian Grimberg   // Blocks of 4 columns
202c8a55531SSebastian Grimberg   for (CeedInt j = (J / JJ) * JJ; j < J; j += 4) {
203c8a55531SSebastian Grimberg     // Blocks of 4 rows
204c8a55531SSebastian Grimberg     for (CeedInt a = 0; a < A_break; a += AA) {
205c8a55531SSebastian Grimberg       rtype vv[AA];  // Output tile to be held in registers
206c8a55531SSebastian Grimberg       for (CeedInt aa = 0; aa < AA; aa++) vv[aa] = loadu(&v[(a + aa) * J + j]);
207c8a55531SSebastian Grimberg 
208c8a55531SSebastian Grimberg       for (CeedInt b = 0; b < B; b++) {
209c8a55531SSebastian Grimberg         rtype tqv;
210c8a55531SSebastian Grimberg         if (J - j == 1) {
211c8a55531SSebastian Grimberg           tqv = set(0.0, 0.0, 0.0, t[(j + 0) * t_stride_0 + b * t_stride_1]);
212c8a55531SSebastian Grimberg         } else if (J - j == 2) {
213c8a55531SSebastian Grimberg           tqv = set(0.0, 0.0, t[(j + 1) * t_stride_0 + b * t_stride_1], t[(j + 0) * t_stride_0 + b * t_stride_1]);
214c8a55531SSebastian Grimberg         } else if (J - 3 == j) {
215c8a55531SSebastian Grimberg           tqv =
216c8a55531SSebastian Grimberg               set(0.0, t[(j + 2) * t_stride_0 + b * t_stride_1], t[(j + 1) * t_stride_0 + b * t_stride_1], t[(j + 0) * t_stride_0 + b * t_stride_1]);
217c8a55531SSebastian Grimberg         } else {
218c8a55531SSebastian Grimberg           tqv = set(t[(j + 3) * t_stride_0 + b * t_stride_1], t[(j + 2) * t_stride_0 + b * t_stride_1], t[(j + 1) * t_stride_0 + b * t_stride_1],
219c8a55531SSebastian Grimberg                     t[(j + 0) * t_stride_0 + b * t_stride_1]);
220c8a55531SSebastian Grimberg         }
221c8a55531SSebastian Grimberg         for (CeedInt aa = 0; aa < AA; aa++) {  // unroll
222c8a55531SSebastian Grimberg           fmadd(vv[aa], tqv, set1(u[(a + aa) * B + b]));
223c8a55531SSebastian Grimberg         }
224c8a55531SSebastian Grimberg       }
225c8a55531SSebastian Grimberg       for (CeedInt aa = 0; aa < AA; aa++) storeu(&v[(a + aa) * J + j], vv[aa]);
226c8a55531SSebastian Grimberg     }
227c8a55531SSebastian Grimberg   }
228c8a55531SSebastian Grimberg   // Remainder of rows, all columns
229c8a55531SSebastian Grimberg   for (CeedInt b = 0; b < B; b++) {
230c8a55531SSebastian Grimberg     for (CeedInt j = (J / JJ) * JJ; j < J; j++) {
231c8a55531SSebastian Grimberg       CeedScalar tq = t[j * t_stride_0 + b * t_stride_1];
232c8a55531SSebastian Grimberg       for (CeedInt a = A_break; a < A; a++) v[a * J + j] += tq * u[a * B + b];
233c8a55531SSebastian Grimberg     }
234c8a55531SSebastian Grimberg   }
235c8a55531SSebastian Grimberg   return CEED_ERROR_SUCCESS;
236c8a55531SSebastian Grimberg }
237c8a55531SSebastian Grimberg 
238c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
239c8a55531SSebastian Grimberg // Tensor Contract - Common Sizes
240c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
241c8a55531SSebastian Grimberg static int CeedTensorContract_Avx_Blocked_4_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
242c8a55531SSebastian Grimberg                                               CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
243c8a55531SSebastian Grimberg   return CeedTensorContract_Avx_Blocked(contract, A, B, C, J, t, t_mode, add, u, v, 4, 8);
244c8a55531SSebastian Grimberg }
245c8a55531SSebastian Grimberg static int CeedTensorContract_Avx_Remainder_8_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
246c8a55531SSebastian Grimberg                                                 CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
247c8a55531SSebastian Grimberg   return CeedTensorContract_Avx_Remainder(contract, A, B, C, J, t, t_mode, add, u, v, 8, 8);
248c8a55531SSebastian Grimberg }
249c8a55531SSebastian Grimberg static int CeedTensorContract_Avx_Single_4_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
250c8a55531SSebastian Grimberg                                              CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
251c8a55531SSebastian Grimberg   return CeedTensorContract_Avx_Single(contract, A, B, C, J, t, t_mode, add, u, v, 4, 8);
252c8a55531SSebastian Grimberg }
253c8a55531SSebastian Grimberg 
254c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
255c8a55531SSebastian Grimberg // Tensor Contract Apply
256c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
257c8a55531SSebastian Grimberg static int CeedTensorContractApply_Avx(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
258c8a55531SSebastian Grimberg                                        CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
259c8a55531SSebastian Grimberg   const CeedInt blk_size = 8;
260c8a55531SSebastian Grimberg 
261c8a55531SSebastian Grimberg   if (!add) {
262c8a55531SSebastian Grimberg     for (CeedInt q = 0; q < A * J * C; q++) v[q] = (CeedScalar)0.0;
263c8a55531SSebastian Grimberg   }
264c8a55531SSebastian Grimberg 
265c8a55531SSebastian Grimberg   if (C == 1) {
266c8a55531SSebastian Grimberg     // Serial C=1 Case
267c8a55531SSebastian Grimberg     CeedTensorContract_Avx_Single_4_8(contract, A, B, C, J, t, t_mode, true, u, v);
268c8a55531SSebastian Grimberg   } else {
269c8a55531SSebastian Grimberg     // Blocks of 8 columns
270c8a55531SSebastian Grimberg     if (C >= blk_size) CeedTensorContract_Avx_Blocked_4_8(contract, A, B, C, J, t, t_mode, true, u, v);
271c8a55531SSebastian Grimberg     // Remainder of columns
272c8a55531SSebastian Grimberg     if (C % blk_size) CeedTensorContract_Avx_Remainder_8_8(contract, A, B, C, J, t, t_mode, true, u, v);
273c8a55531SSebastian Grimberg   }
274c8a55531SSebastian Grimberg 
275c8a55531SSebastian Grimberg   return CEED_ERROR_SUCCESS;
276c8a55531SSebastian Grimberg }
277c8a55531SSebastian Grimberg 
278c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
279c8a55531SSebastian Grimberg // Tensor Contract Create
280c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
281c8a55531SSebastian Grimberg int CeedTensorContractCreate_Avx(CeedBasis basis, CeedTensorContract contract) {
282c8a55531SSebastian Grimberg   Ceed ceed;
283c8a55531SSebastian Grimberg   CeedCallBackend(CeedTensorContractGetCeed(contract, &ceed));
284c8a55531SSebastian Grimberg 
285c8a55531SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply", CeedTensorContractApply_Avx));
286c8a55531SSebastian Grimberg 
287c8a55531SSebastian Grimberg   return CEED_ERROR_SUCCESS;
288c8a55531SSebastian Grimberg }
289c8a55531SSebastian Grimberg 
290c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
291