1*9ba83ac0SJeremy L Thompson // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2c8a55531SSebastian Grimberg // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3c8a55531SSebastian Grimberg //
4c8a55531SSebastian Grimberg // SPDX-License-Identifier: BSD-2-Clause
5c8a55531SSebastian Grimberg //
6c8a55531SSebastian Grimberg // This file is part of CEED: http://github.com/ceed
7c8a55531SSebastian Grimberg
8c8a55531SSebastian Grimberg #include <ceed.h>
9c8a55531SSebastian Grimberg #include <ceed/backend.h>
10c8a55531SSebastian Grimberg #include <immintrin.h>
11c8a55531SSebastian Grimberg #include <stdbool.h>
12c8a55531SSebastian Grimberg
136a96780fSJeremy L Thompson #ifdef CEED_SCALAR_IS_FP64
14c8a55531SSebastian Grimberg #define rtype __m256d
15c8a55531SSebastian Grimberg #define loadu _mm256_loadu_pd
16c8a55531SSebastian Grimberg #define storeu _mm256_storeu_pd
17c8a55531SSebastian Grimberg #define set _mm256_set_pd
18c8a55531SSebastian Grimberg #define set1 _mm256_set1_pd
19c8a55531SSebastian Grimberg // c += a * b
20c8a55531SSebastian Grimberg #ifdef __FMA__
21c8a55531SSebastian Grimberg #define fmadd(c, a, b) (c) = _mm256_fmadd_pd((a), (b), (c))
22c8a55531SSebastian Grimberg #else
23c8a55531SSebastian Grimberg #define fmadd(c, a, b) (c) += _mm256_mul_pd((a), (b))
24c8a55531SSebastian Grimberg #endif
25c8a55531SSebastian Grimberg #else
26c8a55531SSebastian Grimberg #define rtype __m128
27c8a55531SSebastian Grimberg #define loadu _mm_loadu_ps
28c8a55531SSebastian Grimberg #define storeu _mm_storeu_ps
29c8a55531SSebastian Grimberg #define set _mm_set_ps
30c8a55531SSebastian Grimberg #define set1 _mm_set1_ps
31c8a55531SSebastian Grimberg // c += a * b
32c8a55531SSebastian Grimberg #ifdef __FMA__
33c8a55531SSebastian Grimberg #define fmadd(c, a, b) (c) = _mm_fmadd_ps((a), (b), (c))
34c8a55531SSebastian Grimberg #else
35c8a55531SSebastian Grimberg #define fmadd(c, a, b) (c) += _mm_mul_ps((a), (b))
36c8a55531SSebastian Grimberg #endif
37c8a55531SSebastian Grimberg #endif
38c8a55531SSebastian Grimberg
39c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
40c8a55531SSebastian Grimberg // Blocked Tensor Contract
41c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
// Contract along B: v[a,j,c] += t[j,b] * u[a,b,c], where t is addressed as
// t[j*t_stride_0 + b*t_stride_1]. Processes JJ x CC output tiles held entirely
// in SIMD registers; CC must be a multiple of the vector width (4 scalars).
// Columns past (C/CC)*CC are NOT handled here — the caller dispatches them to
// CeedTensorContract_Avx_Remainder.
// NOTE(review): `contract` and `add` are unused; the caller zeroes v when
// !add, so this routine always accumulates into v.
static inline int CeedTensorContract_Avx_Blocked(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J,
                                                 const CeedScalar *restrict t, CeedTransposeMode t_mode, const CeedInt add,
                                                 const CeedScalar *restrict u, CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) {
  // Default (no transpose): t addressed as t[j*B + b]
  CeedInt t_stride_0 = B, t_stride_1 = 1;

  if (t_mode == CEED_TRANSPOSE) {
    // Transposed: t addressed as t[b*J + j]
    t_stride_0 = 1;
    t_stride_1 = J;
  }

  for (CeedInt a = 0; a < A; a++) {
    // Blocks of 4 rows
    for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) {
      for (CeedInt c = 0; c < (C / CC) * CC; c += CC) {
        rtype vv[JJ][CC / 4];  // Output tile to be held in registers
        // Load the current tile of v so the contraction accumulates on top
        for (CeedInt jj = 0; jj < JJ; jj++) {
          for (CeedInt cc = 0; cc < CC / 4; cc++) vv[jj][cc] = loadu(&v[(a * J + j + jj) * C + c + cc * 4]);
        }
        // Accumulate the contraction over b into the register tile
        for (CeedInt b = 0; b < B; b++) {
          for (CeedInt jj = 0; jj < JJ; jj++) {  // unroll
            // Broadcast t[j+jj, b] across all lanes
            rtype tqv = set1(t[(j + jj) * t_stride_0 + b * t_stride_1]);
            for (CeedInt cc = 0; cc < CC / 4; cc++) {  // unroll
              fmadd(vv[jj][cc], tqv, loadu(&u[(a * B + b) * C + c + cc * 4]));
            }
          }
        }
        // Write the finished tile back to v
        for (CeedInt jj = 0; jj < JJ; jj++) {
          for (CeedInt cc = 0; cc < CC / 4; cc++) storeu(&v[(a * J + j + jj) * C + c + cc * 4], vv[jj][cc]);
        }
      }
    }
    // Remainder of rows (fewer than JJ left); same tile structure but the
    // jj loops run only J - j iterations, so they don't unroll
    const CeedInt j = (J / JJ) * JJ;

    if (j < J) {
      for (CeedInt c = 0; c < (C / CC) * CC; c += CC) {
        rtype vv[JJ][CC / 4];  // Output tile to be held in registers

        for (CeedInt jj = 0; jj < J - j; jj++) {
          for (CeedInt cc = 0; cc < CC / 4; cc++) vv[jj][cc] = loadu(&v[(a * J + j + jj) * C + c + cc * 4]);
        }
        for (CeedInt b = 0; b < B; b++) {
          for (CeedInt jj = 0; jj < J - j; jj++) {  // doesn't unroll
            rtype tqv = set1(t[(j + jj) * t_stride_0 + b * t_stride_1]);

            for (CeedInt cc = 0; cc < CC / 4; cc++) {  // unroll
              fmadd(vv[jj][cc], tqv, loadu(&u[(a * B + b) * C + c + cc * 4]));
            }
          }
        }
        for (CeedInt jj = 0; jj < J - j; jj++) {
          for (CeedInt cc = 0; cc < CC / 4; cc++) storeu(&v[(a * J + j + jj) * C + c + cc * 4], vv[jj][cc]);
        }
      }
    }
  }
  return CEED_ERROR_SUCCESS;
}
100c8a55531SSebastian Grimberg
101c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
102c8a55531SSebastian Grimberg // Serial Tensor Contract Remainder
103c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
// Handles the columns the blocked kernel left over: for each a, contracts
// columns c in [(C/CC)*CC, C) — i.e. v[a,j,c] += t[j,b] * u[a,b,c] — using
// vectors of 4 that are zero-padded when fewer than 4 columns remain.
// NOTE(review): `contract` and `add` are unused; the caller zeroes v when
// !add, so this routine always accumulates into v.
static inline int CeedTensorContract_Avx_Remainder(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J,
                                                   const CeedScalar *restrict t, CeedTransposeMode t_mode, const CeedInt add,
                                                   const CeedScalar *restrict u, CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) {
  // Default (no transpose): t addressed as t[j*B + b]; transposed: t[b*J + j]
  CeedInt t_stride_0 = B, t_stride_1 = 1;

  if (t_mode == CEED_TRANSPOSE) {
    t_stride_0 = 1;
    t_stride_1 = J;
  }

  // Hold back the last rows (the J % JJ leftover, or one full JJ block when
  // J divides evenly) for the scalar tail below. The vector loads/stores
  // always touch a full 4 lanes even when fewer columns remain, so the final
  // rows of v must not be accessed with vectors.
  // NOTE(review): presumed intent — the padded lanes round-trip unchanged
  // (tqu is zero there) but still require addressable memory past column c;
  // confirm against the allocation of v.
  const CeedInt J_break = J % JJ ? (J / JJ) * JJ : (J / JJ - 1) * JJ;

  for (CeedInt a = 0; a < A; a++) {
    // Blocks of 4 columns
    for (CeedInt c = (C / CC) * CC; c < C; c += 4) {
      // Blocks of 4 rows
      for (CeedInt j = 0; j < J_break; j += JJ) {
        rtype vv[JJ];  // Output tile to be held in registers

        for (CeedInt jj = 0; jj < JJ; jj++) vv[jj] = loadu(&v[(a * J + j + jj) * C + c]);
        for (CeedInt b = 0; b < B; b++) {
          rtype tqu;

          // Zero-pad the u vector when fewer than 4 columns remain so the
          // extra lanes of vv are left unchanged by the fmadd
          if (C - c == 1) tqu = set(0.0, 0.0, 0.0, u[(a * B + b) * C + c + 0]);
          else if (C - c == 2) tqu = set(0.0, 0.0, u[(a * B + b) * C + c + 1], u[(a * B + b) * C + c + 0]);
          else if (C - c == 3) tqu = set(0.0, u[(a * B + b) * C + c + 2], u[(a * B + b) * C + c + 1], u[(a * B + b) * C + c + 0]);
          else tqu = loadu(&u[(a * B + b) * C + c]);
          for (CeedInt jj = 0; jj < JJ; jj++) {  // unroll
            fmadd(vv[jj], tqu, set1(t[(j + jj) * t_stride_0 + b * t_stride_1]));
          }
        }
        for (CeedInt jj = 0; jj < JJ; jj++) storeu(&v[(a * J + j + jj) * C + c], vv[jj]);
      }
    }
    // Remainder of rows, all leftover columns — plain scalar contraction
    for (CeedInt j = J_break; j < J; j++) {
      for (CeedInt b = 0; b < B; b++) {
        const CeedScalar tq = t[j * t_stride_0 + b * t_stride_1];

        for (CeedInt c = (C / CC) * CC; c < C; c++) v[(a * J + j) * C + c] += tq * u[(a * B + b) * C + c];
      }
    }
  }
  return CEED_ERROR_SUCCESS;
}
149c8a55531SSebastian Grimberg
150c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
151c8a55531SSebastian Grimberg // Serial Tensor Contract C=1
152c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
// Contract for the C == 1 case: v[a,j] += t[j,b] * u[a,b], where t is
// addressed as t[j*t_stride_0 + b*t_stride_1]. Vectorizes along j, holding
// AA x (JJ/4) register tiles; JJ must be a multiple of the vector width
// (4 scalars).
// NOTE(review): `contract` and `add` are unused; the caller zeroes v when
// !add, so this routine always accumulates into v.
static inline int CeedTensorContract_Avx_Single(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
                                                CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v,
                                                const CeedInt AA, const CeedInt JJ) {
  // Default (no transpose): t addressed as t[j*B + b]; transposed: t[b*J + j]
  CeedInt t_stride_0 = B, t_stride_1 = 1;

  if (t_mode == CEED_TRANSPOSE) {
    t_stride_0 = 1;
    t_stride_1 = J;
  }

  // Blocks of 4 rows
  for (CeedInt a = 0; a < (A / AA) * AA; a += AA) {
    for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) {
      rtype vv[AA][JJ / 4];  // Output tile to be held in registers

      // Load the current tile of v so the contraction accumulates on top
      for (CeedInt aa = 0; aa < AA; aa++) {
        for (CeedInt jj = 0; jj < JJ / 4; jj++) vv[aa][jj] = loadu(&v[(a + aa) * J + j + jj * 4]);
      }
      for (CeedInt b = 0; b < B; b++) {
        for (CeedInt jj = 0; jj < JJ / 4; jj++) {  // unroll
          // Gather 4 consecutive-in-j entries of t into one vector
          rtype tqv = set(t[(j + jj * 4 + 3) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 2) * t_stride_0 + b * t_stride_1],
                          t[(j + jj * 4 + 1) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 0) * t_stride_0 + b * t_stride_1]);

          for (CeedInt aa = 0; aa < AA; aa++) {  // unroll
            fmadd(vv[aa][jj], tqv, set1(u[(a + aa) * B + b]));
          }
        }
      }
      for (CeedInt aa = 0; aa < AA; aa++) {
        for (CeedInt jj = 0; jj < JJ / 4; jj++) storeu(&v[(a + aa) * J + j + jj * 4], vv[aa][jj]);
      }
    }
  }
  // Remainder of rows (fewer than AA left); aa loops don't fully unroll
  const CeedInt a = (A / AA) * AA;

  for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) {
    rtype vv[AA][JJ / 4];  // Output tile to be held in registers

    for (CeedInt aa = 0; aa < A - a; aa++) {
      for (CeedInt jj = 0; jj < JJ / 4; jj++) vv[aa][jj] = loadu(&v[(a + aa) * J + j + jj * 4]);
    }
    for (CeedInt b = 0; b < B; b++) {
      for (CeedInt jj = 0; jj < JJ / 4; jj++) {  // unroll
        rtype tqv = set(t[(j + jj * 4 + 3) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 2) * t_stride_0 + b * t_stride_1],
                        t[(j + jj * 4 + 1) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 0) * t_stride_0 + b * t_stride_1]);

        for (CeedInt aa = 0; aa < A - a; aa++) {  // unroll
          fmadd(vv[aa][jj], tqv, set1(u[(a + aa) * B + b]));
        }
      }
    }
    for (CeedInt aa = 0; aa < A - a; aa++) {
      for (CeedInt jj = 0; jj < JJ / 4; jj++) storeu(&v[(a + aa) * J + j + jj * 4], vv[aa][jj]);
    }
  }
  // Column remainder
  // Hold back the final block of rows for the scalar tail; the vector
  // loads/stores below always touch 4 lanes even when fewer than 4 columns
  // of j remain, so the last rows of v must not be accessed with vectors.
  // NOTE(review): presumed intent — confirm against the allocation of v.
  const CeedInt A_break = A % AA ? (A / AA) * AA : (A / AA - 1) * AA;

  // Blocks of 4 columns
  for (CeedInt j = (J / JJ) * JJ; j < J; j += 4) {
    // Blocks of 4 rows
    for (CeedInt a = 0; a < A_break; a += AA) {
      rtype vv[AA];  // Output tile to be held in registers

      for (CeedInt aa = 0; aa < AA; aa++) vv[aa] = loadu(&v[(a + aa) * J + j]);
      for (CeedInt b = 0; b < B; b++) {
        rtype tqv;

        // Zero-pad the t vector when fewer than 4 columns remain so the
        // extra lanes of vv are left unchanged by the fmadd
        if (J - j == 1) {
          tqv = set(0.0, 0.0, 0.0, t[(j + 0) * t_stride_0 + b * t_stride_1]);
        } else if (J - j == 2) {
          tqv = set(0.0, 0.0, t[(j + 1) * t_stride_0 + b * t_stride_1], t[(j + 0) * t_stride_0 + b * t_stride_1]);
        } else if (J - j == 3) {  // normalized from "J - 3 == j" for consistency with the branches above
          tqv =
              set(0.0, t[(j + 2) * t_stride_0 + b * t_stride_1], t[(j + 1) * t_stride_0 + b * t_stride_1], t[(j + 0) * t_stride_0 + b * t_stride_1]);
        } else {
          tqv = set(t[(j + 3) * t_stride_0 + b * t_stride_1], t[(j + 2) * t_stride_0 + b * t_stride_1], t[(j + 1) * t_stride_0 + b * t_stride_1],
                    t[(j + 0) * t_stride_0 + b * t_stride_1]);
        }
        for (CeedInt aa = 0; aa < AA; aa++) {  // unroll
          fmadd(vv[aa], tqv, set1(u[(a + aa) * B + b]));
        }
      }
      for (CeedInt aa = 0; aa < AA; aa++) storeu(&v[(a + aa) * J + j], vv[aa]);
    }
  }
  // Remainder of rows, all leftover columns — plain scalar contraction
  for (CeedInt b = 0; b < B; b++) {
    for (CeedInt j = (J / JJ) * JJ; j < J; j++) {
      const CeedScalar tq = t[j * t_stride_0 + b * t_stride_1];

      for (CeedInt a = A_break; a < A; a++) v[a * J + j] += tq * u[a * B + b];
    }
  }
  return CEED_ERROR_SUCCESS;
}
250c8a55531SSebastian Grimberg
251c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
252c8a55531SSebastian Grimberg // Tensor Contract - Common Sizes
253c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
// Blocked contraction specialized to a 4-row by 8-column register tile.
static int CeedTensorContract_Avx_Blocked_4_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
                                              CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
  const CeedInt tile_rows = 4, tile_cols = 8;

  return CeedTensorContract_Avx_Blocked(contract, A, B, C, J, t, t_mode, add, u, v, tile_rows, tile_cols);
}
// Column-remainder contraction specialized to an 8-row by 8-column blocking.
static int CeedTensorContract_Avx_Remainder_8_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
                                                CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
  const CeedInt tile_rows = 8, tile_cols = 8;

  return CeedTensorContract_Avx_Remainder(contract, A, B, C, J, t, t_mode, add, u, v, tile_rows, tile_cols);
}
// C == 1 contraction specialized to a 4 x 8 register tile.
static int CeedTensorContract_Avx_Single_4_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
                                             CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
  const CeedInt tile_a = 4, tile_j = 8;

  return CeedTensorContract_Avx_Single(contract, A, B, C, J, t, t_mode, add, u, v, tile_a, tile_j);
}
266c8a55531SSebastian Grimberg
267c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
268c8a55531SSebastian Grimberg // Tensor Contract Apply
269c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
// Top-level dispatch for the AVX tensor contraction backend.
// Zeroes v when not accumulating, then routes to the specialized kernel:
// the C == 1 path, or the blocked kernel plus a remainder pass for the
// columns that don't fill an 8-wide block.
static int CeedTensorContractApply_Avx(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
                                       CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
  const CeedInt blk_size = 8;

  // Clear the output so the kernels (which always accumulate) start from zero
  if (!add) {
    const CeedInt num_entries = A * J * C;

    for (CeedInt q = 0; q < num_entries; q++) v[q] = (CeedScalar)0.0;
  }

  // Serial C=1 case
  if (C == 1) {
    CeedTensorContract_Avx_Single_4_8(contract, A, B, C, J, t, t_mode, true, u, v);
    return CEED_ERROR_SUCCESS;
  }

  // Blocks of 8 columns
  if (C >= blk_size) CeedTensorContract_Avx_Blocked_4_8(contract, A, B, C, J, t, t_mode, true, u, v);
  // Remainder of columns
  if (C % blk_size) CeedTensorContract_Avx_Remainder_8_8(contract, A, B, C, J, t, t_mode, true, u, v);
  return CEED_ERROR_SUCCESS;
}
289c8a55531SSebastian Grimberg
290c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
291c8a55531SSebastian Grimberg // Tensor Contract Create
292c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
// Backend setup: registers the AVX `Apply` implementation on this tensor
// contraction object via the owning Ceed context.
int CeedTensorContractCreate_Avx(CeedTensorContract contract) {
  CeedCallBackend(CeedSetBackendFunction(CeedTensorContractReturnCeed(contract), "TensorContract", contract, "Apply", CeedTensorContractApply_Avx));
  return CEED_ERROR_SUCCESS;
}
297c8a55531SSebastian Grimberg
298c8a55531SSebastian Grimberg //------------------------------------------------------------------------------
299