// xref: /libCEED/rust/libceed-sys/c-src/backends/avx/ceed-avx-tensor.c (revision c8a555312a5c51c13b0a524d101ec938ad17e6be)
// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED:  http://github.com/ceed

#include <ceed.h>
#include <ceed/backend.h>
#include <immintrin.h>
#include <stdbool.h>

13*c8a55531SSebastian Grimberg #ifdef _ceed_f64_h
14*c8a55531SSebastian Grimberg #define rtype __m256d
15*c8a55531SSebastian Grimberg #define loadu _mm256_loadu_pd
16*c8a55531SSebastian Grimberg #define storeu _mm256_storeu_pd
17*c8a55531SSebastian Grimberg #define set _mm256_set_pd
18*c8a55531SSebastian Grimberg #define set1 _mm256_set1_pd
19*c8a55531SSebastian Grimberg // c += a * b
20*c8a55531SSebastian Grimberg #ifdef __FMA__
21*c8a55531SSebastian Grimberg #define fmadd(c, a, b) (c) = _mm256_fmadd_pd((a), (b), (c))
22*c8a55531SSebastian Grimberg #else
23*c8a55531SSebastian Grimberg #define fmadd(c, a, b) (c) += _mm256_mul_pd((a), (b))
24*c8a55531SSebastian Grimberg #endif
25*c8a55531SSebastian Grimberg #else
26*c8a55531SSebastian Grimberg #define rtype __m128
27*c8a55531SSebastian Grimberg #define loadu _mm_loadu_ps
28*c8a55531SSebastian Grimberg #define storeu _mm_storeu_ps
29*c8a55531SSebastian Grimberg #define set _mm_set_ps
30*c8a55531SSebastian Grimberg #define set1 _mm_set1_ps
31*c8a55531SSebastian Grimberg // c += a * b
32*c8a55531SSebastian Grimberg #ifdef __FMA__
33*c8a55531SSebastian Grimberg #define fmadd(c, a, b) (c) = _mm_fmadd_ps((a), (b), (c))
34*c8a55531SSebastian Grimberg #else
35*c8a55531SSebastian Grimberg #define fmadd(c, a, b) (c) += _mm_mul_ps((a), (b))
36*c8a55531SSebastian Grimberg #endif
37*c8a55531SSebastian Grimberg #endif
38*c8a55531SSebastian Grimberg 
39*c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
40*c8a55531SSebastian Grimberg // Blocked Tensor Contract
41*c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
42*c8a55531SSebastian Grimberg static inline int CeedTensorContract_Avx_Blocked(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J,
43*c8a55531SSebastian Grimberg                                                  const CeedScalar *restrict t, CeedTransposeMode t_mode, const CeedInt add,
44*c8a55531SSebastian Grimberg                                                  const CeedScalar *restrict u, CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) {
45*c8a55531SSebastian Grimberg   CeedInt t_stride_0 = B, t_stride_1 = 1;
46*c8a55531SSebastian Grimberg   if (t_mode == CEED_TRANSPOSE) {
47*c8a55531SSebastian Grimberg     t_stride_0 = 1;
48*c8a55531SSebastian Grimberg     t_stride_1 = J;
49*c8a55531SSebastian Grimberg   }
50*c8a55531SSebastian Grimberg 
51*c8a55531SSebastian Grimberg   for (CeedInt a = 0; a < A; a++) {
52*c8a55531SSebastian Grimberg     // Blocks of 4 rows
53*c8a55531SSebastian Grimberg     for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) {
54*c8a55531SSebastian Grimberg       for (CeedInt c = 0; c < (C / CC) * CC; c += CC) {
55*c8a55531SSebastian Grimberg         rtype vv[JJ][CC / 4];  // Output tile to be held in registers
56*c8a55531SSebastian Grimberg         for (CeedInt jj = 0; jj < JJ; jj++) {
57*c8a55531SSebastian Grimberg           for (CeedInt cc = 0; cc < CC / 4; cc++) vv[jj][cc] = loadu(&v[(a * J + j + jj) * C + c + cc * 4]);
58*c8a55531SSebastian Grimberg         }
59*c8a55531SSebastian Grimberg 
60*c8a55531SSebastian Grimberg         for (CeedInt b = 0; b < B; b++) {
61*c8a55531SSebastian Grimberg           for (CeedInt jj = 0; jj < JJ; jj++) {  // unroll
62*c8a55531SSebastian Grimberg             rtype tqv = set1(t[(j + jj) * t_stride_0 + b * t_stride_1]);
63*c8a55531SSebastian Grimberg             for (CeedInt cc = 0; cc < CC / 4; cc++) {  // unroll
64*c8a55531SSebastian Grimberg               fmadd(vv[jj][cc], tqv, loadu(&u[(a * B + b) * C + c + cc * 4]));
65*c8a55531SSebastian Grimberg             }
66*c8a55531SSebastian Grimberg           }
67*c8a55531SSebastian Grimberg         }
68*c8a55531SSebastian Grimberg         for (CeedInt jj = 0; jj < JJ; jj++) {
69*c8a55531SSebastian Grimberg           for (CeedInt cc = 0; cc < CC / 4; cc++) storeu(&v[(a * J + j + jj) * C + c + cc * 4], vv[jj][cc]);
70*c8a55531SSebastian Grimberg         }
71*c8a55531SSebastian Grimberg       }
72*c8a55531SSebastian Grimberg     }
73*c8a55531SSebastian Grimberg     // Remainder of rows
74*c8a55531SSebastian Grimberg     CeedInt j = (J / JJ) * JJ;
75*c8a55531SSebastian Grimberg     if (j < J) {
76*c8a55531SSebastian Grimberg       for (CeedInt c = 0; c < (C / CC) * CC; c += CC) {
77*c8a55531SSebastian Grimberg         rtype vv[JJ][CC / 4];  // Output tile to be held in registers
78*c8a55531SSebastian Grimberg         for (CeedInt jj = 0; jj < J - j; jj++) {
79*c8a55531SSebastian Grimberg           for (CeedInt cc = 0; cc < CC / 4; cc++) vv[jj][cc] = loadu(&v[(a * J + j + jj) * C + c + cc * 4]);
80*c8a55531SSebastian Grimberg         }
81*c8a55531SSebastian Grimberg 
82*c8a55531SSebastian Grimberg         for (CeedInt b = 0; b < B; b++) {
83*c8a55531SSebastian Grimberg           for (CeedInt jj = 0; jj < J - j; jj++) {  // doesn't unroll
84*c8a55531SSebastian Grimberg             rtype tqv = set1(t[(j + jj) * t_stride_0 + b * t_stride_1]);
85*c8a55531SSebastian Grimberg             for (CeedInt cc = 0; cc < CC / 4; cc++) {  // unroll
86*c8a55531SSebastian Grimberg               fmadd(vv[jj][cc], tqv, loadu(&u[(a * B + b) * C + c + cc * 4]));
87*c8a55531SSebastian Grimberg             }
88*c8a55531SSebastian Grimberg           }
89*c8a55531SSebastian Grimberg         }
90*c8a55531SSebastian Grimberg         for (CeedInt jj = 0; jj < J - j; jj++) {
91*c8a55531SSebastian Grimberg           for (CeedInt cc = 0; cc < CC / 4; cc++) storeu(&v[(a * J + j + jj) * C + c + cc * 4], vv[jj][cc]);
92*c8a55531SSebastian Grimberg         }
93*c8a55531SSebastian Grimberg       }
94*c8a55531SSebastian Grimberg     }
95*c8a55531SSebastian Grimberg   }
96*c8a55531SSebastian Grimberg   return CEED_ERROR_SUCCESS;
97*c8a55531SSebastian Grimberg }
98*c8a55531SSebastian Grimberg 
99*c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
100*c8a55531SSebastian Grimberg // Serial Tensor Contract Remainder
101*c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
102*c8a55531SSebastian Grimberg static inline int CeedTensorContract_Avx_Remainder(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J,
103*c8a55531SSebastian Grimberg                                                    const CeedScalar *restrict t, CeedTransposeMode t_mode, const CeedInt add,
104*c8a55531SSebastian Grimberg                                                    const CeedScalar *restrict u, CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) {
105*c8a55531SSebastian Grimberg   CeedInt t_stride_0 = B, t_stride_1 = 1;
106*c8a55531SSebastian Grimberg   if (t_mode == CEED_TRANSPOSE) {
107*c8a55531SSebastian Grimberg     t_stride_0 = 1;
108*c8a55531SSebastian Grimberg     t_stride_1 = J;
109*c8a55531SSebastian Grimberg   }
110*c8a55531SSebastian Grimberg 
111*c8a55531SSebastian Grimberg   CeedInt J_break = J % JJ ? (J / JJ) * JJ : (J / JJ - 1) * JJ;
112*c8a55531SSebastian Grimberg   for (CeedInt a = 0; a < A; a++) {
113*c8a55531SSebastian Grimberg     // Blocks of 4 columns
114*c8a55531SSebastian Grimberg     for (CeedInt c = (C / CC) * CC; c < C; c += 4) {
115*c8a55531SSebastian Grimberg       // Blocks of 4 rows
116*c8a55531SSebastian Grimberg       for (CeedInt j = 0; j < J_break; j += JJ) {
117*c8a55531SSebastian Grimberg         rtype vv[JJ];  // Output tile to be held in registers
118*c8a55531SSebastian Grimberg         for (CeedInt jj = 0; jj < JJ; jj++) vv[jj] = loadu(&v[(a * J + j + jj) * C + c]);
119*c8a55531SSebastian Grimberg 
120*c8a55531SSebastian Grimberg         for (CeedInt b = 0; b < B; b++) {
121*c8a55531SSebastian Grimberg           rtype tqu;
122*c8a55531SSebastian Grimberg           if (C - c == 1) tqu = set(0.0, 0.0, 0.0, u[(a * B + b) * C + c + 0]);
123*c8a55531SSebastian Grimberg           else if (C - c == 2) tqu = set(0.0, 0.0, u[(a * B + b) * C + c + 1], u[(a * B + b) * C + c + 0]);
124*c8a55531SSebastian Grimberg           else if (C - c == 3) tqu = set(0.0, u[(a * B + b) * C + c + 2], u[(a * B + b) * C + c + 1], u[(a * B + b) * C + c + 0]);
125*c8a55531SSebastian Grimberg           else tqu = loadu(&u[(a * B + b) * C + c]);
126*c8a55531SSebastian Grimberg           for (CeedInt jj = 0; jj < JJ; jj++) {  // unroll
127*c8a55531SSebastian Grimberg             fmadd(vv[jj], tqu, set1(t[(j + jj) * t_stride_0 + b * t_stride_1]));
128*c8a55531SSebastian Grimberg           }
129*c8a55531SSebastian Grimberg         }
130*c8a55531SSebastian Grimberg         for (CeedInt jj = 0; jj < JJ; jj++) storeu(&v[(a * J + j + jj) * C + c], vv[jj]);
131*c8a55531SSebastian Grimberg       }
132*c8a55531SSebastian Grimberg     }
133*c8a55531SSebastian Grimberg     // Remainder of rows, all columns
134*c8a55531SSebastian Grimberg     for (CeedInt j = J_break; j < J; j++) {
135*c8a55531SSebastian Grimberg       for (CeedInt b = 0; b < B; b++) {
136*c8a55531SSebastian Grimberg         CeedScalar tq = t[j * t_stride_0 + b * t_stride_1];
137*c8a55531SSebastian Grimberg         for (CeedInt c = (C / CC) * CC; c < C; c++) v[(a * J + j) * C + c] += tq * u[(a * B + b) * C + c];
138*c8a55531SSebastian Grimberg       }
139*c8a55531SSebastian Grimberg     }
140*c8a55531SSebastian Grimberg   }
141*c8a55531SSebastian Grimberg   return CEED_ERROR_SUCCESS;
142*c8a55531SSebastian Grimberg }
143*c8a55531SSebastian Grimberg 
144*c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
145*c8a55531SSebastian Grimberg // Serial Tensor Contract C=1
146*c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
147*c8a55531SSebastian Grimberg static inline int CeedTensorContract_Avx_Single(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
148*c8a55531SSebastian Grimberg                                                 CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v,
149*c8a55531SSebastian Grimberg                                                 const CeedInt AA, const CeedInt JJ) {
150*c8a55531SSebastian Grimberg   CeedInt t_stride_0 = B, t_stride_1 = 1;
151*c8a55531SSebastian Grimberg   if (t_mode == CEED_TRANSPOSE) {
152*c8a55531SSebastian Grimberg     t_stride_0 = 1;
153*c8a55531SSebastian Grimberg     t_stride_1 = J;
154*c8a55531SSebastian Grimberg   }
155*c8a55531SSebastian Grimberg 
156*c8a55531SSebastian Grimberg   // Blocks of 4 rows
157*c8a55531SSebastian Grimberg   for (CeedInt a = 0; a < (A / AA) * AA; a += AA) {
158*c8a55531SSebastian Grimberg     for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) {
159*c8a55531SSebastian Grimberg       rtype vv[AA][JJ / 4];  // Output tile to be held in registers
160*c8a55531SSebastian Grimberg       for (CeedInt aa = 0; aa < AA; aa++) {
161*c8a55531SSebastian Grimberg         for (CeedInt jj = 0; jj < JJ / 4; jj++) vv[aa][jj] = loadu(&v[(a + aa) * J + j + jj * 4]);
162*c8a55531SSebastian Grimberg       }
163*c8a55531SSebastian Grimberg 
164*c8a55531SSebastian Grimberg       for (CeedInt b = 0; b < B; b++) {
165*c8a55531SSebastian Grimberg         for (CeedInt jj = 0; jj < JJ / 4; jj++) {  // unroll
166*c8a55531SSebastian Grimberg           rtype tqv = set(t[(j + jj * 4 + 3) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 2) * t_stride_0 + b * t_stride_1],
167*c8a55531SSebastian Grimberg                           t[(j + jj * 4 + 1) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 0) * t_stride_0 + b * t_stride_1]);
168*c8a55531SSebastian Grimberg           for (CeedInt aa = 0; aa < AA; aa++) {  // unroll
169*c8a55531SSebastian Grimberg             fmadd(vv[aa][jj], tqv, set1(u[(a + aa) * B + b]));
170*c8a55531SSebastian Grimberg           }
171*c8a55531SSebastian Grimberg         }
172*c8a55531SSebastian Grimberg       }
173*c8a55531SSebastian Grimberg       for (CeedInt aa = 0; aa < AA; aa++) {
174*c8a55531SSebastian Grimberg         for (CeedInt jj = 0; jj < JJ / 4; jj++) storeu(&v[(a + aa) * J + j + jj * 4], vv[aa][jj]);
175*c8a55531SSebastian Grimberg       }
176*c8a55531SSebastian Grimberg     }
177*c8a55531SSebastian Grimberg   }
178*c8a55531SSebastian Grimberg   // Remainder of rows
179*c8a55531SSebastian Grimberg   CeedInt a = (A / AA) * AA;
180*c8a55531SSebastian Grimberg   for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) {
181*c8a55531SSebastian Grimberg     rtype vv[AA][JJ / 4];  // Output tile to be held in registers
182*c8a55531SSebastian Grimberg     for (CeedInt aa = 0; aa < A - a; aa++) {
183*c8a55531SSebastian Grimberg       for (CeedInt jj = 0; jj < JJ / 4; jj++) vv[aa][jj] = loadu(&v[(a + aa) * J + j + jj * 4]);
184*c8a55531SSebastian Grimberg     }
185*c8a55531SSebastian Grimberg 
186*c8a55531SSebastian Grimberg     for (CeedInt b = 0; b < B; b++) {
187*c8a55531SSebastian Grimberg       for (CeedInt jj = 0; jj < JJ / 4; jj++) {  // unroll
188*c8a55531SSebastian Grimberg         rtype tqv = set(t[(j + jj * 4 + 3) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 2) * t_stride_0 + b * t_stride_1],
189*c8a55531SSebastian Grimberg                         t[(j + jj * 4 + 1) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 0) * t_stride_0 + b * t_stride_1]);
190*c8a55531SSebastian Grimberg         for (CeedInt aa = 0; aa < A - a; aa++) {  // unroll
191*c8a55531SSebastian Grimberg           fmadd(vv[aa][jj], tqv, set1(u[(a + aa) * B + b]));
192*c8a55531SSebastian Grimberg         }
193*c8a55531SSebastian Grimberg       }
194*c8a55531SSebastian Grimberg     }
195*c8a55531SSebastian Grimberg     for (CeedInt aa = 0; aa < A - a; aa++) {
196*c8a55531SSebastian Grimberg       for (CeedInt jj = 0; jj < JJ / 4; jj++) storeu(&v[(a + aa) * J + j + jj * 4], vv[aa][jj]);
197*c8a55531SSebastian Grimberg     }
198*c8a55531SSebastian Grimberg   }
199*c8a55531SSebastian Grimberg   // Column remainder
200*c8a55531SSebastian Grimberg   CeedInt A_break = A % AA ? (A / AA) * AA : (A / AA - 1) * AA;
201*c8a55531SSebastian Grimberg   // Blocks of 4 columns
202*c8a55531SSebastian Grimberg   for (CeedInt j = (J / JJ) * JJ; j < J; j += 4) {
203*c8a55531SSebastian Grimberg     // Blocks of 4 rows
204*c8a55531SSebastian Grimberg     for (CeedInt a = 0; a < A_break; a += AA) {
205*c8a55531SSebastian Grimberg       rtype vv[AA];  // Output tile to be held in registers
206*c8a55531SSebastian Grimberg       for (CeedInt aa = 0; aa < AA; aa++) vv[aa] = loadu(&v[(a + aa) * J + j]);
207*c8a55531SSebastian Grimberg 
208*c8a55531SSebastian Grimberg       for (CeedInt b = 0; b < B; b++) {
209*c8a55531SSebastian Grimberg         rtype tqv;
210*c8a55531SSebastian Grimberg         if (J - j == 1) {
211*c8a55531SSebastian Grimberg           tqv = set(0.0, 0.0, 0.0, t[(j + 0) * t_stride_0 + b * t_stride_1]);
212*c8a55531SSebastian Grimberg         } else if (J - j == 2) {
213*c8a55531SSebastian Grimberg           tqv = set(0.0, 0.0, t[(j + 1) * t_stride_0 + b * t_stride_1], t[(j + 0) * t_stride_0 + b * t_stride_1]);
214*c8a55531SSebastian Grimberg         } else if (J - 3 == j) {
215*c8a55531SSebastian Grimberg           tqv =
216*c8a55531SSebastian Grimberg               set(0.0, t[(j + 2) * t_stride_0 + b * t_stride_1], t[(j + 1) * t_stride_0 + b * t_stride_1], t[(j + 0) * t_stride_0 + b * t_stride_1]);
217*c8a55531SSebastian Grimberg         } else {
218*c8a55531SSebastian Grimberg           tqv = set(t[(j + 3) * t_stride_0 + b * t_stride_1], t[(j + 2) * t_stride_0 + b * t_stride_1], t[(j + 1) * t_stride_0 + b * t_stride_1],
219*c8a55531SSebastian Grimberg                     t[(j + 0) * t_stride_0 + b * t_stride_1]);
220*c8a55531SSebastian Grimberg         }
221*c8a55531SSebastian Grimberg         for (CeedInt aa = 0; aa < AA; aa++) {  // unroll
222*c8a55531SSebastian Grimberg           fmadd(vv[aa], tqv, set1(u[(a + aa) * B + b]));
223*c8a55531SSebastian Grimberg         }
224*c8a55531SSebastian Grimberg       }
225*c8a55531SSebastian Grimberg       for (CeedInt aa = 0; aa < AA; aa++) storeu(&v[(a + aa) * J + j], vv[aa]);
226*c8a55531SSebastian Grimberg     }
227*c8a55531SSebastian Grimberg   }
228*c8a55531SSebastian Grimberg   // Remainder of rows, all columns
229*c8a55531SSebastian Grimberg   for (CeedInt b = 0; b < B; b++) {
230*c8a55531SSebastian Grimberg     for (CeedInt j = (J / JJ) * JJ; j < J; j++) {
231*c8a55531SSebastian Grimberg       CeedScalar tq = t[j * t_stride_0 + b * t_stride_1];
232*c8a55531SSebastian Grimberg       for (CeedInt a = A_break; a < A; a++) v[a * J + j] += tq * u[a * B + b];
233*c8a55531SSebastian Grimberg     }
234*c8a55531SSebastian Grimberg   }
235*c8a55531SSebastian Grimberg   return CEED_ERROR_SUCCESS;
236*c8a55531SSebastian Grimberg }
237*c8a55531SSebastian Grimberg 
238*c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
239*c8a55531SSebastian Grimberg // Tensor Contract - Common Sizes
240*c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
241*c8a55531SSebastian Grimberg static int CeedTensorContract_Avx_Blocked_4_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
242*c8a55531SSebastian Grimberg                                               CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
243*c8a55531SSebastian Grimberg   return CeedTensorContract_Avx_Blocked(contract, A, B, C, J, t, t_mode, add, u, v, 4, 8);
244*c8a55531SSebastian Grimberg }
245*c8a55531SSebastian Grimberg static int CeedTensorContract_Avx_Remainder_8_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
246*c8a55531SSebastian Grimberg                                                 CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
247*c8a55531SSebastian Grimberg   return CeedTensorContract_Avx_Remainder(contract, A, B, C, J, t, t_mode, add, u, v, 8, 8);
248*c8a55531SSebastian Grimberg }
249*c8a55531SSebastian Grimberg static int CeedTensorContract_Avx_Single_4_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
250*c8a55531SSebastian Grimberg                                              CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
251*c8a55531SSebastian Grimberg   return CeedTensorContract_Avx_Single(contract, A, B, C, J, t, t_mode, add, u, v, 4, 8);
252*c8a55531SSebastian Grimberg }
253*c8a55531SSebastian Grimberg 
254*c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
255*c8a55531SSebastian Grimberg // Tensor Contract Apply
256*c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
257*c8a55531SSebastian Grimberg static int CeedTensorContractApply_Avx(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
258*c8a55531SSebastian Grimberg                                        CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
259*c8a55531SSebastian Grimberg   const CeedInt blk_size = 8;
260*c8a55531SSebastian Grimberg 
261*c8a55531SSebastian Grimberg   if (!add) {
262*c8a55531SSebastian Grimberg     for (CeedInt q = 0; q < A * J * C; q++) v[q] = (CeedScalar)0.0;
263*c8a55531SSebastian Grimberg   }
264*c8a55531SSebastian Grimberg 
265*c8a55531SSebastian Grimberg   if (C == 1) {
266*c8a55531SSebastian Grimberg     // Serial C=1 Case
267*c8a55531SSebastian Grimberg     CeedTensorContract_Avx_Single_4_8(contract, A, B, C, J, t, t_mode, true, u, v);
268*c8a55531SSebastian Grimberg   } else {
269*c8a55531SSebastian Grimberg     // Blocks of 8 columns
270*c8a55531SSebastian Grimberg     if (C >= blk_size) CeedTensorContract_Avx_Blocked_4_8(contract, A, B, C, J, t, t_mode, true, u, v);
271*c8a55531SSebastian Grimberg     // Remainder of columns
272*c8a55531SSebastian Grimberg     if (C % blk_size) CeedTensorContract_Avx_Remainder_8_8(contract, A, B, C, J, t, t_mode, true, u, v);
273*c8a55531SSebastian Grimberg   }
274*c8a55531SSebastian Grimberg 
275*c8a55531SSebastian Grimberg   return CEED_ERROR_SUCCESS;
276*c8a55531SSebastian Grimberg }
277*c8a55531SSebastian Grimberg 
278*c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
279*c8a55531SSebastian Grimberg // Tensor Contract Create
280*c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
281*c8a55531SSebastian Grimberg int CeedTensorContractCreate_Avx(CeedBasis basis, CeedTensorContract contract) {
282*c8a55531SSebastian Grimberg   Ceed ceed;
283*c8a55531SSebastian Grimberg   CeedCallBackend(CeedTensorContractGetCeed(contract, &ceed));
284*c8a55531SSebastian Grimberg 
285*c8a55531SSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply", CeedTensorContractApply_Avx));
286*c8a55531SSebastian Grimberg 
287*c8a55531SSebastian Grimberg   return CEED_ERROR_SUCCESS;
288*c8a55531SSebastian Grimberg }
289*c8a55531SSebastian Grimberg 
//------------------------------------------------------------------------------