// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed

#include <ceed.h>
#include <ceed/backend.h>
#include <immintrin.h>
#include <stdbool.h>

// Precision dispatch: when ceed.h pulled in the double-precision scalar header
// (its CEED_F64_H include guard is defined), use 256-bit AVX double vectors;
// otherwise use 128-bit single-precision vectors.  In both cases one vector
// register ("rtype") holds exactly 4 CeedScalar values, which is why the
// kernels below tile in multiples of 4.
#ifdef CEED_F64_H
#define rtype __m256d
#define loadu _mm256_loadu_pd
#define storeu _mm256_storeu_pd
#define set _mm256_set_pd
#define set1 _mm256_set1_pd
// c += a * b (single fused instruction when the compiler advertises FMA)
#ifdef __FMA__
#define fmadd(c, a, b) (c) = _mm256_fmadd_pd((a), (b), (c))
#else
#define fmadd(c, a, b) (c) += _mm256_mul_pd((a), (b))
#endif
#else
#define rtype __m128
#define loadu _mm_loadu_ps
#define storeu _mm_storeu_ps
#define set _mm_set_ps
#define set1 _mm_set1_ps
// c += a * b
#ifdef __FMA__
#define fmadd(c, a, b) (c) = _mm_fmadd_ps((a), (b), (c))
#else
#define fmadd(c, a, b) (c) += _mm_mul_ps((a), (b))
#endif
#endif

//------------------------------------------------------------------------------
// Blocked Tensor Contract
//
// Accumulates v[a][j][c] += sum_b t[j][b] * u[a][b][c] over register tiles of
// JJ rows by CC columns (CC a multiple of 4; each vv[jj][cc] is one 4-wide
// vector).  With t_mode == CEED_NOTRANSPOSE, t is indexed as t[j*B + b]
// (J-by-B row-major); with CEED_TRANSPOSE, as t[j + b*J].
//
// Only full CC-wide column blocks are touched here — the column tail
// (c >= (C/CC)*CC) is left for CeedTensorContract_Avx_Remainder.  Row tails
// (j >= (J/JJ)*JJ) ARE handled, in the second half of the `a` loop, with the
// same tile shape but jj bounded by J - j (that loop cannot be unrolled since
// its trip count is not a compile-time constant).
//
// `contract` and `add` are unused in this kernel: it always accumulates into
// v, and the caller (CeedTensorContractApply_Avx) zeroes v first when add is
// false.
//------------------------------------------------------------------------------
static inline int CeedTensorContract_Avx_Blocked(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J,
                                                 const CeedScalar *restrict t, CeedTransposeMode t_mode, const CeedInt add,
                                                 const CeedScalar *restrict u, CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) {
  CeedInt t_stride_0 = B, t_stride_1 = 1;

  if (t_mode == CEED_TRANSPOSE) {
    t_stride_0 = 1;
    t_stride_1 = J;
  }

  for (CeedInt a = 0; a < A; a++) {
    // Blocks of 4 rows
    for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) {
      for (CeedInt c = 0; c < (C / CC) * CC; c += CC) {
        rtype vv[JJ][CC / 4];  // Output tile to be held in registers

        // Load the current v tile so the kernel accumulates into it
        for (CeedInt jj = 0; jj < JJ; jj++) {
          for (CeedInt cc = 0; cc < CC / 4; cc++) vv[jj][cc] = loadu(&v[(a * J + j + jj) * C + c + cc * 4]);
        }
        // Contraction over b: broadcast one t entry, multiply-add a u vector
        for (CeedInt b = 0; b < B; b++) {
          for (CeedInt jj = 0; jj < JJ; jj++) {  // unroll
            rtype tqv = set1(t[(j + jj) * t_stride_0 + b * t_stride_1]);

            for (CeedInt cc = 0; cc < CC / 4; cc++) {  // unroll
              fmadd(vv[jj][cc], tqv, loadu(&u[(a * B + b) * C + c + cc * 4]));
            }
          }
        }
        // Write the finished tile back
        for (CeedInt jj = 0; jj < JJ; jj++) {
          for (CeedInt cc = 0; cc < CC / 4; cc++) storeu(&v[(a * J + j + jj) * C + c + cc * 4], vv[jj][cc]);
        }
      }
    }
    // Remainder of rows: same tiling with jj limited to the J - j leftover rows
    const CeedInt j = (J / JJ) * JJ;

    if (j < J) {
      for (CeedInt c = 0; c < (C / CC) * CC; c += CC) {
        rtype vv[JJ][CC / 4];  // Output tile to be held in registers

        for (CeedInt jj = 0; jj < J - j; jj++) {
          for (CeedInt cc = 0; cc < CC / 4; cc++) vv[jj][cc] = loadu(&v[(a * J + j + jj) * C + c + cc * 4]);
        }
        for (CeedInt b = 0; b < B; b++) {
          for (CeedInt jj = 0; jj < J - j; jj++) {  // doesn't unroll
            rtype tqv = set1(t[(j + jj) * t_stride_0 + b * t_stride_1]);

            for (CeedInt cc = 0; cc < CC / 4; cc++) {  // unroll
              fmadd(vv[jj][cc], tqv, loadu(&u[(a * B + b) * C + c + cc * 4]));
            }
          }
        }
        for (CeedInt jj = 0; jj < J - j; jj++) {
          for (CeedInt cc = 0; cc < CC / 4; cc++) storeu(&v[(a * J + j + jj) * C + c + cc * 4], vv[jj][cc]);
        }
      }
    }
  }
  return CEED_ERROR_SUCCESS;
}
//------------------------------------------------------------------------------
// Serial Tensor Contract Remainder
//
// Handles the column tail c in [(C/CC)*CC, C) that the blocked kernel skipped,
// accumulating v[a][j][c] += sum_b t[j][b] * u[a][b][c].  Columns are processed
// 4 at a time; when fewer than 4 remain, the u vector is zero-padded with
// set(), so the fmadd contributes 0 to the extra lanes and the full-width
// load/store round-trip leaves the neighboring v entries unchanged.
//
// J_break deliberately excludes the LAST full JJ-row block whenever J is an
// exact multiple of JJ, pushing those rows onto the scalar loop below —
// presumably so the 4-wide loads/stores of the vector path never run past the
// end of v on the final rows (NOTE(review): confirm against the allocation in
// the caller).  `contract` and `add` are unused here; the caller zeroes v when
// not accumulating.
//------------------------------------------------------------------------------
static inline int CeedTensorContract_Avx_Remainder(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J,
                                                   const CeedScalar *restrict t, CeedTransposeMode t_mode, const CeedInt add,
                                                   const CeedScalar *restrict u, CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) {
  CeedInt t_stride_0 = B, t_stride_1 = 1;

  if (t_mode == CEED_TRANSPOSE) {
    t_stride_0 = 1;
    t_stride_1 = J;
  }

  // Last row block handled by the vector path (see header comment)
  const CeedInt J_break = J % JJ ? (J / JJ) * JJ : (J / JJ - 1) * JJ;

  for (CeedInt a = 0; a < A; a++) {
    // Blocks of 4 columns
    for (CeedInt c = (C / CC) * CC; c < C; c += 4) {
      // Blocks of 4 rows
      for (CeedInt j = 0; j < J_break; j += JJ) {
        rtype vv[JJ];  // Output tile to be held in registers

        for (CeedInt jj = 0; jj < JJ; jj++) vv[jj] = loadu(&v[(a * J + j + jj) * C + c]);
        for (CeedInt b = 0; b < B; b++) {
          rtype tqu;

          // Zero-pad u when fewer than 4 columns remain (set() takes lanes
          // high-to-low, hence the reversed argument order)
          if (C - c == 1) tqu = set(0.0, 0.0, 0.0, u[(a * B + b) * C + c + 0]);
          else if (C - c == 2) tqu = set(0.0, 0.0, u[(a * B + b) * C + c + 1], u[(a * B + b) * C + c + 0]);
          else if (C - c == 3) tqu = set(0.0, u[(a * B + b) * C + c + 2], u[(a * B + b) * C + c + 1], u[(a * B + b) * C + c + 0]);
          else tqu = loadu(&u[(a * B + b) * C + c]);
          for (CeedInt jj = 0; jj < JJ; jj++) {  // unroll
            fmadd(vv[jj], tqu, set1(t[(j + jj) * t_stride_0 + b * t_stride_1]));
          }
        }
        for (CeedInt jj = 0; jj < JJ; jj++) storeu(&v[(a * J + j + jj) * C + c], vv[jj]);
      }
    }
    // Remainder of rows, all columns: plain scalar accumulation
    for (CeedInt j = J_break; j < J; j++) {
      for (CeedInt b = 0; b < B; b++) {
        const CeedScalar tq = t[j * t_stride_0 + b * t_stride_1];

        for (CeedInt c = (C / CC) * CC; c < C; c++) v[(a * J + j) * C + c] += tq * u[(a * B + b) * C + c];
      }
    }
  }
  return CEED_ERROR_SUCCESS;
}
//------------------------------------------------------------------------------
// Serial Tensor Contract C=1
//
// Specialization for C == 1, where v collapses to A-by-J and u to A-by-B:
// accumulates v[a][j] += sum_b t[j][b] * u[a][b].  Since there is no c index
// to vectorize over, vectorization is across J instead: tiles of AA rows by
// JJ columns (JJ a multiple of 4), with t gathered 4-at-a-time via set() and
// u broadcast via set1().
//
// Structure: (1) full AA x JJ tiles; (2) row tail a >= (A/AA)*AA with the same
// tile shape but aa bounded by A - a; (3) column tail j >= (J/JJ)*JJ, 4 columns
// at a time with t zero-padded when fewer than 4 remain (the zero lanes add 0,
// so the full-width v load/store leaves neighboring entries unchanged);
// (4) scalar cleanup for rows past A_break over the column tail.  A_break,
// like J_break in the Remainder kernel, excludes the last full AA block when
// A divides evenly so the vector stores of section (3) stop short of the end
// of v.  `contract` and `add` are unused; the caller zeroes v when needed.
//------------------------------------------------------------------------------
static inline int CeedTensorContract_Avx_Single(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
                                                CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v,
                                                const CeedInt AA, const CeedInt JJ) {
  CeedInt t_stride_0 = B, t_stride_1 = 1;

  if (t_mode == CEED_TRANSPOSE) {
    t_stride_0 = 1;
    t_stride_1 = J;
  }

  // (1) Full tiles: blocks of AA rows by JJ columns
  for (CeedInt a = 0; a < (A / AA) * AA; a += AA) {
    for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) {
      rtype vv[AA][JJ / 4];  // Output tile to be held in registers

      for (CeedInt aa = 0; aa < AA; aa++) {
        for (CeedInt jj = 0; jj < JJ / 4; jj++) vv[aa][jj] = loadu(&v[(a + aa) * J + j + jj * 4]);
      }
      for (CeedInt b = 0; b < B; b++) {
        for (CeedInt jj = 0; jj < JJ / 4; jj++) {  // unroll
          // Gather 4 consecutive-j entries of t (set() lanes are high-to-low)
          rtype tqv = set(t[(j + jj * 4 + 3) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 2) * t_stride_0 + b * t_stride_1],
                          t[(j + jj * 4 + 1) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 0) * t_stride_0 + b * t_stride_1]);

          for (CeedInt aa = 0; aa < AA; aa++) {  // unroll
            fmadd(vv[aa][jj], tqv, set1(u[(a + aa) * B + b]));
          }
        }
      }
      for (CeedInt aa = 0; aa < AA; aa++) {
        for (CeedInt jj = 0; jj < JJ / 4; jj++) storeu(&v[(a + aa) * J + j + jj * 4], vv[aa][jj]);
      }
    }
  }
  // (2) Remainder of rows: aa limited to A - a (zero iterations if A % AA == 0)
  const CeedInt a = (A / AA) * AA;

  for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) {
    rtype vv[AA][JJ / 4];  // Output tile to be held in registers

    for (CeedInt aa = 0; aa < A - a; aa++) {
      for (CeedInt jj = 0; jj < JJ / 4; jj++) vv[aa][jj] = loadu(&v[(a + aa) * J + j + jj * 4]);
    }
    for (CeedInt b = 0; b < B; b++) {
      for (CeedInt jj = 0; jj < JJ / 4; jj++) {  // unroll
        rtype tqv = set(t[(j + jj * 4 + 3) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 2) * t_stride_0 + b * t_stride_1],
                        t[(j + jj * 4 + 1) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 0) * t_stride_0 + b * t_stride_1]);

        for (CeedInt aa = 0; aa < A - a; aa++) {  // unroll
          fmadd(vv[aa][jj], tqv, set1(u[(a + aa) * B + b]));
        }
      }
    }
    for (CeedInt aa = 0; aa < A - a; aa++) {
      for (CeedInt jj = 0; jj < JJ / 4; jj++) storeu(&v[(a + aa) * J + j + jj * 4], vv[aa][jj]);
    }
  }
  // (3) Column remainder; last vectorized row block (see header comment)
  const CeedInt A_break = A % AA ? (A / AA) * AA : (A / AA - 1) * AA;

  // Blocks of 4 columns
  for (CeedInt j = (J / JJ) * JJ; j < J; j += 4) {
    // Blocks of 4 rows
    for (CeedInt a = 0; a < A_break; a += AA) {
      rtype vv[AA];  // Output tile to be held in registers

      for (CeedInt aa = 0; aa < AA; aa++) vv[aa] = loadu(&v[(a + aa) * J + j]);
      for (CeedInt b = 0; b < B; b++) {
        rtype tqv;

        // Zero-pad t when fewer than 4 columns remain
        if (J - j == 1) {
          tqv = set(0.0, 0.0, 0.0, t[(j + 0) * t_stride_0 + b * t_stride_1]);
        } else if (J - j == 2) {
          tqv = set(0.0, 0.0, t[(j + 1) * t_stride_0 + b * t_stride_1], t[(j + 0) * t_stride_0 + b * t_stride_1]);
        } else if (J - 3 == j) {  // i.e. J - j == 3
          tqv =
              set(0.0, t[(j + 2) * t_stride_0 + b * t_stride_1], t[(j + 1) * t_stride_0 + b * t_stride_1], t[(j + 0) * t_stride_0 + b * t_stride_1]);
        } else {
          tqv = set(t[(j + 3) * t_stride_0 + b * t_stride_1], t[(j + 2) * t_stride_0 + b * t_stride_1], t[(j + 1) * t_stride_0 + b * t_stride_1],
                    t[(j + 0) * t_stride_0 + b * t_stride_1]);
        }
        for (CeedInt aa = 0; aa < AA; aa++) {  // unroll
          fmadd(vv[aa], tqv, set1(u[(a + aa) * B + b]));
        }
      }
      for (CeedInt aa = 0; aa < AA; aa++) storeu(&v[(a + aa) * J + j], vv[aa]);
    }
  }
  // (4) Remainder of rows, all columns: scalar cleanup
  for (CeedInt b = 0; b < B; b++) {
    for (CeedInt j = (J / JJ) * JJ; j < J; j++) {
      const CeedScalar tq = t[j * t_stride_0 + b * t_stride_1];

      for (CeedInt a = A_break; a < A; a++) v[a * J + j] += tq * u[a * B + b];
    }
  }
  return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Tensor Contract - Common Sizes
//------------------------------------------------------------------------------
// Blocked kernel instantiated with JJ = 4 rows, CC = 8 columns per tile
static int CeedTensorContract_Avx_Blocked_4_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
                                              CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
  return CeedTensorContract_Avx_Blocked(contract, A, B, C, J, t, t_mode, add, u, v, 4, 8);
}
CeedTensorContract_Avx_Remainder_8_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 259c8a55531SSebastian Grimberg CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 260c8a55531SSebastian Grimberg return CeedTensorContract_Avx_Remainder(contract, A, B, C, J, t, t_mode, add, u, v, 8, 8); 261c8a55531SSebastian Grimberg } 262c8a55531SSebastian Grimberg static int CeedTensorContract_Avx_Single_4_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 263c8a55531SSebastian Grimberg CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 264c8a55531SSebastian Grimberg return CeedTensorContract_Avx_Single(contract, A, B, C, J, t, t_mode, add, u, v, 4, 8); 265c8a55531SSebastian Grimberg } 266c8a55531SSebastian Grimberg 267c8a55531SSebastian Grimberg //------------------------------------------------------------------------------ 268c8a55531SSebastian Grimberg // Tensor Contract Apply 269c8a55531SSebastian Grimberg //------------------------------------------------------------------------------ 270c8a55531SSebastian Grimberg static int CeedTensorContractApply_Avx(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 271c8a55531SSebastian Grimberg CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 272c8a55531SSebastian Grimberg const CeedInt blk_size = 8; 273c8a55531SSebastian Grimberg 274c8a55531SSebastian Grimberg if (!add) { 275c8a55531SSebastian Grimberg for (CeedInt q = 0; q < A * J * C; q++) v[q] = (CeedScalar)0.0; 276c8a55531SSebastian Grimberg } 277c8a55531SSebastian Grimberg 278c8a55531SSebastian Grimberg if (C == 1) { 279c8a55531SSebastian Grimberg // Serial C=1 Case 280c8a55531SSebastian Grimberg CeedTensorContract_Avx_Single_4_8(contract, A, B, C, J, t, t_mode, true, u, 
v); 281c8a55531SSebastian Grimberg } else { 282c8a55531SSebastian Grimberg // Blocks of 8 columns 283c8a55531SSebastian Grimberg if (C >= blk_size) CeedTensorContract_Avx_Blocked_4_8(contract, A, B, C, J, t, t_mode, true, u, v); 284c8a55531SSebastian Grimberg // Remainder of columns 285c8a55531SSebastian Grimberg if (C % blk_size) CeedTensorContract_Avx_Remainder_8_8(contract, A, B, C, J, t, t_mode, true, u, v); 286c8a55531SSebastian Grimberg } 287c8a55531SSebastian Grimberg return CEED_ERROR_SUCCESS; 288c8a55531SSebastian Grimberg } 289c8a55531SSebastian Grimberg 290c8a55531SSebastian Grimberg //------------------------------------------------------------------------------ 291c8a55531SSebastian Grimberg // Tensor Contract Create 292c8a55531SSebastian Grimberg //------------------------------------------------------------------------------ 293*a71faab1SSebastian Grimberg int CeedTensorContractCreate_Avx(CeedTensorContract contract) { 294c8a55531SSebastian Grimberg Ceed ceed; 295ad70ee2cSJeremy L Thompson 296c8a55531SSebastian Grimberg CeedCallBackend(CeedTensorContractGetCeed(contract, &ceed)); 297c8a55531SSebastian Grimberg CeedCallBackend(CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply", CeedTensorContractApply_Avx)); 298c8a55531SSebastian Grimberg return CEED_ERROR_SUCCESS; 299c8a55531SSebastian Grimberg } 300c8a55531SSebastian Grimberg 301c8a55531SSebastian Grimberg //------------------------------------------------------------------------------ 302