1*c8a55531SSebastian Grimberg // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2*c8a55531SSebastian Grimberg // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3*c8a55531SSebastian Grimberg // 4*c8a55531SSebastian Grimberg // SPDX-License-Identifier: BSD-2-Clause 5*c8a55531SSebastian Grimberg // 6*c8a55531SSebastian Grimberg // This file is part of CEED: http://github.com/ceed 7*c8a55531SSebastian Grimberg 8*c8a55531SSebastian Grimberg #include <ceed.h> 9*c8a55531SSebastian Grimberg #include <ceed/backend.h> 10*c8a55531SSebastian Grimberg #include <immintrin.h> 11*c8a55531SSebastian Grimberg #include <stdbool.h> 12*c8a55531SSebastian Grimberg 13*c8a55531SSebastian Grimberg #ifdef _ceed_f64_h 14*c8a55531SSebastian Grimberg #define rtype __m256d 15*c8a55531SSebastian Grimberg #define loadu _mm256_loadu_pd 16*c8a55531SSebastian Grimberg #define storeu _mm256_storeu_pd 17*c8a55531SSebastian Grimberg #define set _mm256_set_pd 18*c8a55531SSebastian Grimberg #define set1 _mm256_set1_pd 19*c8a55531SSebastian Grimberg // c += a * b 20*c8a55531SSebastian Grimberg #ifdef __FMA__ 21*c8a55531SSebastian Grimberg #define fmadd(c, a, b) (c) = _mm256_fmadd_pd((a), (b), (c)) 22*c8a55531SSebastian Grimberg #else 23*c8a55531SSebastian Grimberg #define fmadd(c, a, b) (c) += _mm256_mul_pd((a), (b)) 24*c8a55531SSebastian Grimberg #endif 25*c8a55531SSebastian Grimberg #else 26*c8a55531SSebastian Grimberg #define rtype __m128 27*c8a55531SSebastian Grimberg #define loadu _mm_loadu_ps 28*c8a55531SSebastian Grimberg #define storeu _mm_storeu_ps 29*c8a55531SSebastian Grimberg #define set _mm_set_ps 30*c8a55531SSebastian Grimberg #define set1 _mm_set1_ps 31*c8a55531SSebastian Grimberg // c += a * b 32*c8a55531SSebastian Grimberg #ifdef __FMA__ 33*c8a55531SSebastian Grimberg #define fmadd(c, a, b) (c) = _mm_fmadd_ps((a), (b), (c)) 34*c8a55531SSebastian Grimberg #else 35*c8a55531SSebastian Grimberg #define fmadd(c, a, b) (c) += _mm_mul_ps((a), (b)) 36*c8a55531SSebastian Grimberg #endif 37*c8a55531SSebastian Grimberg #endif 38*c8a55531SSebastian Grimberg 39*c8a55531SSebastian Grimberg //------------------------------------------------------------------------------ 40*c8a55531SSebastian Grimberg // Blocked Tensor Contract 41*c8a55531SSebastian Grimberg //------------------------------------------------------------------------------ 42*c8a55531SSebastian Grimberg static inline int CeedTensorContract_Avx_Blocked(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, 43*c8a55531SSebastian Grimberg const CeedScalar *restrict t, CeedTransposeMode t_mode, const CeedInt add, 44*c8a55531SSebastian Grimberg const CeedScalar *restrict u, CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) { 45*c8a55531SSebastian Grimberg CeedInt t_stride_0 = B, t_stride_1 = 1; 46*c8a55531SSebastian Grimberg if (t_mode == CEED_TRANSPOSE) { 47*c8a55531SSebastian Grimberg t_stride_0 = 1; 48*c8a55531SSebastian Grimberg t_stride_1 = J; 49*c8a55531SSebastian Grimberg } 50*c8a55531SSebastian Grimberg 51*c8a55531SSebastian Grimberg for (CeedInt a = 0; a < A; a++) { 52*c8a55531SSebastian Grimberg // Blocks of 4 rows 53*c8a55531SSebastian Grimberg for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) { 54*c8a55531SSebastian Grimberg for (CeedInt c = 0; c < (C / CC) * CC; c += CC) { 55*c8a55531SSebastian Grimberg rtype vv[JJ][CC / 4]; // Output tile to be held in registers 56*c8a55531SSebastian Grimberg for (CeedInt jj = 0; jj < JJ; jj++) { 57*c8a55531SSebastian Grimberg for (CeedInt cc = 0; cc < CC / 4; cc++) vv[jj][cc] = loadu(&v[(a * J + j + jj) * C + c + cc * 4]); 58*c8a55531SSebastian Grimberg } 59*c8a55531SSebastian Grimberg 60*c8a55531SSebastian Grimberg for (CeedInt b = 0; b < B; b++) { 61*c8a55531SSebastian Grimberg for (CeedInt jj = 0; jj < JJ; jj++) { // unroll 62*c8a55531SSebastian Grimberg rtype tqv = set1(t[(j + jj) * t_stride_0 + b * t_stride_1]); 63*c8a55531SSebastian Grimberg for (CeedInt cc = 0; cc < CC / 4; cc++) { // unroll 64*c8a55531SSebastian Grimberg fmadd(vv[jj][cc], tqv, loadu(&u[(a * B + b) * C + c + cc * 4])); 65*c8a55531SSebastian Grimberg } 66*c8a55531SSebastian Grimberg } 67*c8a55531SSebastian Grimberg } 68*c8a55531SSebastian Grimberg for (CeedInt jj = 0; jj < JJ; jj++) { 69*c8a55531SSebastian Grimberg for (CeedInt cc = 0; cc < CC / 4; cc++) storeu(&v[(a * J + j + jj) * C + c + cc * 4], vv[jj][cc]); 70*c8a55531SSebastian Grimberg } 71*c8a55531SSebastian Grimberg } 72*c8a55531SSebastian Grimberg } 73*c8a55531SSebastian Grimberg // Remainder of rows 74*c8a55531SSebastian Grimberg CeedInt j = (J / JJ) * JJ; 75*c8a55531SSebastian Grimberg if (j < J) { 76*c8a55531SSebastian Grimberg for (CeedInt c = 0; c < (C / CC) * CC; c += CC) { 77*c8a55531SSebastian Grimberg rtype vv[JJ][CC / 4]; // Output tile to be held in registers 78*c8a55531SSebastian Grimberg for (CeedInt jj = 0; jj < J - j; jj++) { 79*c8a55531SSebastian Grimberg for (CeedInt cc = 0; cc < CC / 4; cc++) vv[jj][cc] = loadu(&v[(a * J + j + jj) * C + c + cc * 4]); 80*c8a55531SSebastian Grimberg } 81*c8a55531SSebastian Grimberg 82*c8a55531SSebastian Grimberg for (CeedInt b = 0; b < B; b++) { 83*c8a55531SSebastian Grimberg for (CeedInt jj = 0; jj < J - j; jj++) { // doesn't unroll 84*c8a55531SSebastian Grimberg rtype tqv = set1(t[(j + jj) * t_stride_0 + b * t_stride_1]); 85*c8a55531SSebastian Grimberg for (CeedInt cc = 0; cc < CC / 4; cc++) { // unroll 86*c8a55531SSebastian Grimberg fmadd(vv[jj][cc], tqv, loadu(&u[(a * B + b) * C + c + cc * 4])); 87*c8a55531SSebastian Grimberg } 88*c8a55531SSebastian Grimberg } 89*c8a55531SSebastian Grimberg } 90*c8a55531SSebastian Grimberg for (CeedInt jj = 0; jj < J - j; jj++) { 91*c8a55531SSebastian Grimberg for (CeedInt cc = 0; cc < CC / 4; cc++) storeu(&v[(a * J + j + jj) * C + c + cc * 4], vv[jj][cc]); 92*c8a55531SSebastian Grimberg } 93*c8a55531SSebastian Grimberg } 94*c8a55531SSebastian Grimberg } 95*c8a55531SSebastian Grimberg } 96*c8a55531SSebastian Grimberg return CEED_ERROR_SUCCESS; 97*c8a55531SSebastian Grimberg } 98*c8a55531SSebastian Grimberg 99*c8a55531SSebastian Grimberg //------------------------------------------------------------------------------ 100*c8a55531SSebastian Grimberg // Serial Tensor Contract Remainder 101*c8a55531SSebastian Grimberg //------------------------------------------------------------------------------ 102*c8a55531SSebastian Grimberg static inline int CeedTensorContract_Avx_Remainder(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, 103*c8a55531SSebastian Grimberg const CeedScalar *restrict t, CeedTransposeMode t_mode, const CeedInt add, 104*c8a55531SSebastian Grimberg const CeedScalar *restrict u, CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) { 105*c8a55531SSebastian Grimberg CeedInt t_stride_0 = B, t_stride_1 = 1; 106*c8a55531SSebastian Grimberg if (t_mode == CEED_TRANSPOSE) { 107*c8a55531SSebastian Grimberg t_stride_0 = 1; 108*c8a55531SSebastian Grimberg t_stride_1 = J; 109*c8a55531SSebastian Grimberg } 110*c8a55531SSebastian Grimberg 111*c8a55531SSebastian Grimberg CeedInt J_break = J % JJ ? (J / JJ) * JJ : (J / JJ - 1) * JJ; 112*c8a55531SSebastian Grimberg for (CeedInt a = 0; a < A; a++) { 113*c8a55531SSebastian Grimberg // Blocks of 4 columns 114*c8a55531SSebastian Grimberg for (CeedInt c = (C / CC) * CC; c < C; c += 4) { 115*c8a55531SSebastian Grimberg // Blocks of 4 rows 116*c8a55531SSebastian Grimberg for (CeedInt j = 0; j < J_break; j += JJ) { 117*c8a55531SSebastian Grimberg rtype vv[JJ]; // Output tile to be held in registers 118*c8a55531SSebastian Grimberg for (CeedInt jj = 0; jj < JJ; jj++) vv[jj] = loadu(&v[(a * J + j + jj) * C + c]); 119*c8a55531SSebastian Grimberg 120*c8a55531SSebastian Grimberg for (CeedInt b = 0; b < B; b++) { 121*c8a55531SSebastian Grimberg rtype tqu; 122*c8a55531SSebastian Grimberg if (C - c == 1) tqu = set(0.0, 0.0, 0.0, u[(a * B + b) * C + c + 0]); 123*c8a55531SSebastian Grimberg else if (C - c == 2) tqu = set(0.0, 0.0, u[(a * B + b) * C + c + 1], u[(a * B + b) * C + c + 0]); 124*c8a55531SSebastian Grimberg else if (C - c == 3) tqu = set(0.0, u[(a * B + b) * C + c + 2], u[(a * B + b) * C + c + 1], u[(a * B + b) * C + c + 0]); 125*c8a55531SSebastian Grimberg else tqu = loadu(&u[(a * B + b) * C + c]); 126*c8a55531SSebastian Grimberg for (CeedInt jj = 0; jj < JJ; jj++) { // unroll 127*c8a55531SSebastian Grimberg fmadd(vv[jj], tqu, set1(t[(j + jj) * t_stride_0 + b * t_stride_1])); 128*c8a55531SSebastian Grimberg } 129*c8a55531SSebastian Grimberg } 130*c8a55531SSebastian Grimberg for (CeedInt jj = 0; jj < JJ; jj++) storeu(&v[(a * J + j + jj) * C + c], vv[jj]); 131*c8a55531SSebastian Grimberg } 132*c8a55531SSebastian Grimberg } 133*c8a55531SSebastian Grimberg // Remainder of rows, all columns 134*c8a55531SSebastian Grimberg for (CeedInt j = J_break; j < J; j++) { 135*c8a55531SSebastian Grimberg for (CeedInt b = 0; b < B; b++) { 136*c8a55531SSebastian Grimberg CeedScalar tq = t[j * t_stride_0 + b * t_stride_1]; 137*c8a55531SSebastian Grimberg for (CeedInt c = (C / CC) * CC; c < C; c++) v[(a * J + j) * C + c] += tq * u[(a * B + b) * C + c]; 138*c8a55531SSebastian Grimberg } 139*c8a55531SSebastian Grimberg } 140*c8a55531SSebastian Grimberg } 141*c8a55531SSebastian Grimberg return CEED_ERROR_SUCCESS; 142*c8a55531SSebastian Grimberg } 143*c8a55531SSebastian Grimberg 144*c8a55531SSebastian Grimberg //------------------------------------------------------------------------------ 145*c8a55531SSebastian Grimberg // Serial Tensor Contract C=1 146*c8a55531SSebastian Grimberg //------------------------------------------------------------------------------ 147*c8a55531SSebastian Grimberg static inline int CeedTensorContract_Avx_Single(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 148*c8a55531SSebastian Grimberg CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v, 149*c8a55531SSebastian Grimberg const CeedInt AA, const CeedInt JJ) { 150*c8a55531SSebastian Grimberg CeedInt t_stride_0 = B, t_stride_1 = 1; 151*c8a55531SSebastian Grimberg if (t_mode == CEED_TRANSPOSE) { 152*c8a55531SSebastian Grimberg t_stride_0 = 1; 153*c8a55531SSebastian Grimberg t_stride_1 = J; 154*c8a55531SSebastian Grimberg } 155*c8a55531SSebastian Grimberg 156*c8a55531SSebastian Grimberg // Blocks of 4 rows 157*c8a55531SSebastian Grimberg for (CeedInt a = 0; a < (A / AA) * AA; a += AA) { 158*c8a55531SSebastian Grimberg for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) { 159*c8a55531SSebastian Grimberg rtype vv[AA][JJ / 4]; // Output tile to be held in registers 160*c8a55531SSebastian Grimberg for (CeedInt aa = 0; aa < AA; aa++) { 161*c8a55531SSebastian Grimberg for (CeedInt jj = 0; jj < JJ / 4; jj++) vv[aa][jj] = loadu(&v[(a + aa) * J + j + jj * 4]); 162*c8a55531SSebastian Grimberg } 163*c8a55531SSebastian Grimberg 164*c8a55531SSebastian Grimberg for (CeedInt b = 0; b < B; b++) { 165*c8a55531SSebastian Grimberg for (CeedInt jj = 0; jj < JJ / 4; jj++) { // unroll 166*c8a55531SSebastian Grimberg rtype tqv = set(t[(j + jj * 4 + 3) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 2) * t_stride_0 + b * t_stride_1], 167*c8a55531SSebastian Grimberg t[(j + jj * 4 + 1) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 0) * t_stride_0 + b * t_stride_1]); 168*c8a55531SSebastian Grimberg for (CeedInt aa = 0; aa < AA; aa++) { // unroll 169*c8a55531SSebastian Grimberg fmadd(vv[aa][jj], tqv, set1(u[(a + aa) * B + b])); 170*c8a55531SSebastian Grimberg } 171*c8a55531SSebastian Grimberg } 172*c8a55531SSebastian Grimberg } 173*c8a55531SSebastian Grimberg for (CeedInt aa = 0; aa < AA; aa++) { 174*c8a55531SSebastian Grimberg for (CeedInt jj = 0; jj < JJ / 4; jj++) storeu(&v[(a + aa) * J + j + jj * 4], vv[aa][jj]); 175*c8a55531SSebastian Grimberg } 176*c8a55531SSebastian Grimberg } 177*c8a55531SSebastian Grimberg } 178*c8a55531SSebastian Grimberg // Remainder of rows 179*c8a55531SSebastian Grimberg CeedInt a = (A / AA) * AA; 180*c8a55531SSebastian Grimberg for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) { 181*c8a55531SSebastian Grimberg rtype vv[AA][JJ / 4]; // Output tile to be held in registers 182*c8a55531SSebastian Grimberg for (CeedInt aa = 0; aa < A - a; aa++) { 183*c8a55531SSebastian Grimberg for (CeedInt jj = 0; jj < JJ / 4; jj++) vv[aa][jj] = loadu(&v[(a + aa) * J + j + jj * 4]); 184*c8a55531SSebastian Grimberg } 185*c8a55531SSebastian Grimberg 186*c8a55531SSebastian Grimberg for (CeedInt b = 0; b < B; b++) { 187*c8a55531SSebastian Grimberg for (CeedInt jj = 0; jj < JJ / 4; jj++) { // unroll 188*c8a55531SSebastian Grimberg rtype tqv = set(t[(j + jj * 4 + 3) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 2) * t_stride_0 + b * t_stride_1], 189*c8a55531SSebastian Grimberg t[(j + jj * 4 + 1) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 0) * t_stride_0 + b * t_stride_1]); 190*c8a55531SSebastian Grimberg for (CeedInt aa = 0; aa < A - a; aa++) { // unroll 191*c8a55531SSebastian Grimberg fmadd(vv[aa][jj], tqv, set1(u[(a + aa) * B + b])); 192*c8a55531SSebastian Grimberg } 193*c8a55531SSebastian Grimberg } 194*c8a55531SSebastian Grimberg } 195*c8a55531SSebastian Grimberg for (CeedInt aa = 0; aa < A - a; aa++) { 196*c8a55531SSebastian Grimberg for (CeedInt jj = 0; jj < JJ / 4; jj++) storeu(&v[(a + aa) * J + j + jj * 4], vv[aa][jj]); 197*c8a55531SSebastian Grimberg } 198*c8a55531SSebastian Grimberg } 199*c8a55531SSebastian Grimberg // Column remainder 200*c8a55531SSebastian Grimberg CeedInt A_break = A % AA ? (A / AA) * AA : (A / AA - 1) * AA; 201*c8a55531SSebastian Grimberg // Blocks of 4 columns 202*c8a55531SSebastian Grimberg for (CeedInt j = (J / JJ) * JJ; j < J; j += 4) { 203*c8a55531SSebastian Grimberg // Blocks of 4 rows 204*c8a55531SSebastian Grimberg for (CeedInt a = 0; a < A_break; a += AA) { 205*c8a55531SSebastian Grimberg rtype vv[AA]; // Output tile to be held in registers 206*c8a55531SSebastian Grimberg for (CeedInt aa = 0; aa < AA; aa++) vv[aa] = loadu(&v[(a + aa) * J + j]); 207*c8a55531SSebastian Grimberg 208*c8a55531SSebastian Grimberg for (CeedInt b = 0; b < B; b++) { 209*c8a55531SSebastian Grimberg rtype tqv; 210*c8a55531SSebastian Grimberg if (J - j == 1) { 211*c8a55531SSebastian Grimberg tqv = set(0.0, 0.0, 0.0, t[(j + 0) * t_stride_0 + b * t_stride_1]); 212*c8a55531SSebastian Grimberg } else if (J - j == 2) { 213*c8a55531SSebastian Grimberg tqv = set(0.0, 0.0, t[(j + 1) * t_stride_0 + b * t_stride_1], t[(j + 0) * t_stride_0 + b * t_stride_1]); 214*c8a55531SSebastian Grimberg } else if (J - 3 == j) { 215*c8a55531SSebastian Grimberg tqv = 216*c8a55531SSebastian Grimberg set(0.0, t[(j + 2) * t_stride_0 + b * t_stride_1], t[(j + 1) * t_stride_0 + b * t_stride_1], t[(j + 0) * t_stride_0 + b * t_stride_1]); 217*c8a55531SSebastian Grimberg } else { 218*c8a55531SSebastian Grimberg tqv = set(t[(j + 3) * t_stride_0 + b * t_stride_1], t[(j + 2) * t_stride_0 + b * t_stride_1], t[(j + 1) * t_stride_0 + b * t_stride_1], 219*c8a55531SSebastian Grimberg t[(j + 0) * t_stride_0 + b * t_stride_1]); 220*c8a55531SSebastian Grimberg } 221*c8a55531SSebastian Grimberg for (CeedInt aa = 0; aa < AA; aa++) { // unroll 222*c8a55531SSebastian Grimberg fmadd(vv[aa], tqv, set1(u[(a + aa) * B + b])); 223*c8a55531SSebastian Grimberg } 224*c8a55531SSebastian Grimberg } 225*c8a55531SSebastian Grimberg for (CeedInt aa = 0; aa < AA; aa++) storeu(&v[(a + aa) * J + j], vv[aa]); 226*c8a55531SSebastian Grimberg } 227*c8a55531SSebastian Grimberg } 228*c8a55531SSebastian Grimberg // Remainder of rows, all columns 229*c8a55531SSebastian Grimberg for (CeedInt b = 0; b < B; b++) { 230*c8a55531SSebastian Grimberg for (CeedInt j = (J / JJ) * JJ; j < J; j++) { 231*c8a55531SSebastian Grimberg CeedScalar tq = t[j * t_stride_0 + b * t_stride_1]; 232*c8a55531SSebastian Grimberg for (CeedInt a = A_break; a < A; a++) v[a * J + j] += tq * u[a * B + b]; 233*c8a55531SSebastian Grimberg } 234*c8a55531SSebastian Grimberg } 235*c8a55531SSebastian Grimberg return CEED_ERROR_SUCCESS; 236*c8a55531SSebastian Grimberg } 237*c8a55531SSebastian Grimberg 238*c8a55531SSebastian Grimberg //------------------------------------------------------------------------------ 239*c8a55531SSebastian Grimberg // Tensor Contract - Common Sizes 240*c8a55531SSebastian Grimberg //------------------------------------------------------------------------------ 241*c8a55531SSebastian Grimberg static int CeedTensorContract_Avx_Blocked_4_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 242*c8a55531SSebastian Grimberg CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 243*c8a55531SSebastian Grimberg return CeedTensorContract_Avx_Blocked(contract, A, B, C, J, t, t_mode, add, u, v, 4, 8); 244*c8a55531SSebastian Grimberg } 245*c8a55531SSebastian Grimberg static int CeedTensorContract_Avx_Remainder_8_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 246*c8a55531SSebastian Grimberg CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 247*c8a55531SSebastian Grimberg return CeedTensorContract_Avx_Remainder(contract, A, B, C, J, t, t_mode, add, u, v, 8, 8); 248*c8a55531SSebastian Grimberg } 249*c8a55531SSebastian Grimberg static int CeedTensorContract_Avx_Single_4_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 250*c8a55531SSebastian Grimberg CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 251*c8a55531SSebastian Grimberg return CeedTensorContract_Avx_Single(contract, A, B, C, J, t, t_mode, add, u, v, 4, 8); 252*c8a55531SSebastian Grimberg } 253*c8a55531SSebastian Grimberg 254*c8a55531SSebastian Grimberg //------------------------------------------------------------------------------ 255*c8a55531SSebastian Grimberg // Tensor Contract Apply 256*c8a55531SSebastian Grimberg //------------------------------------------------------------------------------ 257*c8a55531SSebastian Grimberg static int CeedTensorContractApply_Avx(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 258*c8a55531SSebastian Grimberg CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 259*c8a55531SSebastian Grimberg const CeedInt blk_size = 8; 260*c8a55531SSebastian Grimberg 261*c8a55531SSebastian Grimberg if (!add) { 262*c8a55531SSebastian Grimberg for (CeedInt q = 0; q < A * J * C; q++) v[q] = (CeedScalar)0.0; 263*c8a55531SSebastian Grimberg } 264*c8a55531SSebastian Grimberg 265*c8a55531SSebastian Grimberg if (C == 1) { 266*c8a55531SSebastian Grimberg // Serial C=1 Case 267*c8a55531SSebastian Grimberg CeedTensorContract_Avx_Single_4_8(contract, A, B, C, J, t, t_mode, true, u, v); 268*c8a55531SSebastian Grimberg } else { 269*c8a55531SSebastian Grimberg // Blocks of 8 columns 270*c8a55531SSebastian Grimberg if (C >= blk_size) CeedTensorContract_Avx_Blocked_4_8(contract, A, B, C, J, t, t_mode, true, u, v); 271*c8a55531SSebastian Grimberg // Remainder of columns 272*c8a55531SSebastian Grimberg if (C % blk_size) CeedTensorContract_Avx_Remainder_8_8(contract, A, B, C, J, t, t_mode, true, u, v); 273*c8a55531SSebastian Grimberg } 274*c8a55531SSebastian Grimberg 275*c8a55531SSebastian Grimberg return CEED_ERROR_SUCCESS; 276*c8a55531SSebastian Grimberg } 277*c8a55531SSebastian Grimberg 278*c8a55531SSebastian Grimberg //------------------------------------------------------------------------------ 279*c8a55531SSebastian Grimberg // Tensor Contract Create 280*c8a55531SSebastian Grimberg //------------------------------------------------------------------------------ 281*c8a55531SSebastian Grimberg int CeedTensorContractCreate_Avx(CeedBasis basis, CeedTensorContract contract) { 282*c8a55531SSebastian Grimberg Ceed ceed; 283*c8a55531SSebastian Grimberg CeedCallBackend(CeedTensorContractGetCeed(contract, &ceed)); 284*c8a55531SSebastian Grimberg 285*c8a55531SSebastian Grimberg CeedCallBackend(CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply", CeedTensorContractApply_Avx)); 286*c8a55531SSebastian Grimberg 287*c8a55531SSebastian Grimberg return CEED_ERROR_SUCCESS; 288*c8a55531SSebastian Grimberg } 289*c8a55531SSebastian Grimberg 290*c8a55531SSebastian Grimberg //------------------------------------------------------------------------------ 291