1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <immintrin.h> 11 #include <stdbool.h> 12 13 #ifdef CEED_F64_H 14 #define rtype __m256d 15 #define loadu _mm256_loadu_pd 16 #define storeu _mm256_storeu_pd 17 #define set _mm256_set_pd 18 #define set1 _mm256_set1_pd 19 // c += a * b 20 #ifdef __FMA__ 21 #define fmadd(c, a, b) (c) = _mm256_fmadd_pd((a), (b), (c)) 22 #else 23 #define fmadd(c, a, b) (c) += _mm256_mul_pd((a), (b)) 24 #endif 25 #else 26 #define rtype __m128 27 #define loadu _mm_loadu_ps 28 #define storeu _mm_storeu_ps 29 #define set _mm_set_ps 30 #define set1 _mm_set1_ps 31 // c += a * b 32 #ifdef __FMA__ 33 #define fmadd(c, a, b) (c) = _mm_fmadd_ps((a), (b), (c)) 34 #else 35 #define fmadd(c, a, b) (c) += _mm_mul_ps((a), (b)) 36 #endif 37 #endif 38 39 //------------------------------------------------------------------------------ 40 // Blocked Tensor Contract 41 //------------------------------------------------------------------------------ 42 static inline int CeedTensorContract_Avx_Blocked(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, 43 const CeedScalar *restrict t, CeedTransposeMode t_mode, const CeedInt add, 44 const CeedScalar *restrict u, CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) { 45 CeedInt t_stride_0 = B, t_stride_1 = 1; 46 if (t_mode == CEED_TRANSPOSE) { 47 t_stride_0 = 1; 48 t_stride_1 = J; 49 } 50 51 for (CeedInt a = 0; a < A; a++) { 52 // Blocks of 4 rows 53 for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) { 54 for (CeedInt c = 0; c < (C / CC) * CC; c += CC) { 55 rtype vv[JJ][CC / 4]; // Output tile to be held in registers 56 for (CeedInt jj = 0; jj < JJ; jj++) { 57 for (CeedInt cc = 0; cc < CC / 4; cc++) vv[jj][cc] = loadu(&v[(a * J + j + jj) * C + c + cc * 4]); 58 } 59 60 for (CeedInt b = 0; b < B; b++) { 61 for (CeedInt jj = 0; jj < JJ; jj++) { // unroll 62 rtype tqv = set1(t[(j + jj) * t_stride_0 + b * t_stride_1]); 63 for (CeedInt cc = 0; cc < CC / 4; cc++) { // unroll 64 fmadd(vv[jj][cc], tqv, loadu(&u[(a * B + b) * C + c + cc * 4])); 65 } 66 } 67 } 68 for (CeedInt jj = 0; jj < JJ; jj++) { 69 for (CeedInt cc = 0; cc < CC / 4; cc++) storeu(&v[(a * J + j + jj) * C + c + cc * 4], vv[jj][cc]); 70 } 71 } 72 } 73 // Remainder of rows 74 CeedInt j = (J / JJ) * JJ; 75 if (j < J) { 76 for (CeedInt c = 0; c < (C / CC) * CC; c += CC) { 77 rtype vv[JJ][CC / 4]; // Output tile to be held in registers 78 for (CeedInt jj = 0; jj < J - j; jj++) { 79 for (CeedInt cc = 0; cc < CC / 4; cc++) vv[jj][cc] = loadu(&v[(a * J + j + jj) * C + c + cc * 4]); 80 } 81 82 for (CeedInt b = 0; b < B; b++) { 83 for (CeedInt jj = 0; jj < J - j; jj++) { // doesn't unroll 84 rtype tqv = set1(t[(j + jj) * t_stride_0 + b * t_stride_1]); 85 for (CeedInt cc = 0; cc < CC / 4; cc++) { // unroll 86 fmadd(vv[jj][cc], tqv, loadu(&u[(a * B + b) * C + c + cc * 4])); 87 } 88 } 89 } 90 for (CeedInt jj = 0; jj < J - j; jj++) { 91 for (CeedInt cc = 0; cc < CC / 4; cc++) storeu(&v[(a * J + j + jj) * C + c + cc * 4], vv[jj][cc]); 92 } 93 } 94 } 95 } 96 return CEED_ERROR_SUCCESS; 97 } 98 99 //------------------------------------------------------------------------------ 100 // Serial Tensor Contract Remainder 101 //------------------------------------------------------------------------------ 102 static inline int CeedTensorContract_Avx_Remainder(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, 103 const CeedScalar *restrict t, CeedTransposeMode t_mode, const CeedInt add, 104 const CeedScalar *restrict u, CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) { 105 CeedInt t_stride_0 = B, t_stride_1 = 1; 106 if (t_mode == CEED_TRANSPOSE) { 107 t_stride_0 = 1; 108 t_stride_1 = J; 109 } 110 111 CeedInt J_break = J % JJ ? (J / JJ) * JJ : (J / JJ - 1) * JJ; 112 for (CeedInt a = 0; a < A; a++) { 113 // Blocks of 4 columns 114 for (CeedInt c = (C / CC) * CC; c < C; c += 4) { 115 // Blocks of 4 rows 116 for (CeedInt j = 0; j < J_break; j += JJ) { 117 rtype vv[JJ]; // Output tile to be held in registers 118 for (CeedInt jj = 0; jj < JJ; jj++) vv[jj] = loadu(&v[(a * J + j + jj) * C + c]); 119 120 for (CeedInt b = 0; b < B; b++) { 121 rtype tqu; 122 if (C - c == 1) tqu = set(0.0, 0.0, 0.0, u[(a * B + b) * C + c + 0]); 123 else if (C - c == 2) tqu = set(0.0, 0.0, u[(a * B + b) * C + c + 1], u[(a * B + b) * C + c + 0]); 124 else if (C - c == 3) tqu = set(0.0, u[(a * B + b) * C + c + 2], u[(a * B + b) * C + c + 1], u[(a * B + b) * C + c + 0]); 125 else tqu = loadu(&u[(a * B + b) * C + c]); 126 for (CeedInt jj = 0; jj < JJ; jj++) { // unroll 127 fmadd(vv[jj], tqu, set1(t[(j + jj) * t_stride_0 + b * t_stride_1])); 128 } 129 } 130 for (CeedInt jj = 0; jj < JJ; jj++) storeu(&v[(a * J + j + jj) * C + c], vv[jj]); 131 } 132 } 133 // Remainder of rows, all columns 134 for (CeedInt j = J_break; j < J; j++) { 135 for (CeedInt b = 0; b < B; b++) { 136 CeedScalar tq = t[j * t_stride_0 + b * t_stride_1]; 137 for (CeedInt c = (C / CC) * CC; c < C; c++) v[(a * J + j) * C + c] += tq * u[(a * B + b) * C + c]; 138 } 139 } 140 } 141 return CEED_ERROR_SUCCESS; 142 } 143 144 //------------------------------------------------------------------------------ 145 // Serial Tensor Contract C=1 146 //------------------------------------------------------------------------------ 147 static inline int CeedTensorContract_Avx_Single(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 148 CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v, 149 const CeedInt AA, const CeedInt JJ) { 150 CeedInt t_stride_0 = B, t_stride_1 = 1; 151 if (t_mode == CEED_TRANSPOSE) { 152 t_stride_0 = 1; 153 t_stride_1 = J; 154 } 155 156 // Blocks of 4 rows 157 for (CeedInt a = 0; a < (A / AA) * AA; a += AA) { 158 for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) { 159 rtype vv[AA][JJ / 4]; // Output tile to be held in registers 160 for (CeedInt aa = 0; aa < AA; aa++) { 161 for (CeedInt jj = 0; jj < JJ / 4; jj++) vv[aa][jj] = loadu(&v[(a + aa) * J + j + jj * 4]); 162 } 163 164 for (CeedInt b = 0; b < B; b++) { 165 for (CeedInt jj = 0; jj < JJ / 4; jj++) { // unroll 166 rtype tqv = set(t[(j + jj * 4 + 3) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 2) * t_stride_0 + b * t_stride_1], 167 t[(j + jj * 4 + 1) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 0) * t_stride_0 + b * t_stride_1]); 168 for (CeedInt aa = 0; aa < AA; aa++) { // unroll 169 fmadd(vv[aa][jj], tqv, set1(u[(a + aa) * B + b])); 170 } 171 } 172 } 173 for (CeedInt aa = 0; aa < AA; aa++) { 174 for (CeedInt jj = 0; jj < JJ / 4; jj++) storeu(&v[(a + aa) * J + j + jj * 4], vv[aa][jj]); 175 } 176 } 177 } 178 // Remainder of rows 179 CeedInt a = (A / AA) * AA; 180 for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) { 181 rtype vv[AA][JJ / 4]; // Output tile to be held in registers 182 for (CeedInt aa = 0; aa < A - a; aa++) { 183 for (CeedInt jj = 0; jj < JJ / 4; jj++) vv[aa][jj] = loadu(&v[(a + aa) * J + j + jj * 4]); 184 } 185 186 for (CeedInt b = 0; b < B; b++) { 187 for (CeedInt jj = 0; jj < JJ / 4; jj++) { // unroll 188 rtype tqv = set(t[(j + jj * 4 + 3) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 2) * t_stride_0 + b * t_stride_1], 189 t[(j + jj * 4 + 1) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 0) * t_stride_0 + b * t_stride_1]); 190 for (CeedInt aa = 0; aa < A - a; aa++) { // unroll 191 fmadd(vv[aa][jj], tqv, set1(u[(a + aa) * B + b])); 192 } 193 } 194 } 195 for (CeedInt aa = 0; aa < A - a; aa++) { 196 for (CeedInt jj = 0; jj < JJ / 4; jj++) storeu(&v[(a + aa) * J + j + jj * 4], vv[aa][jj]); 197 } 198 } 199 // Column remainder 200 CeedInt A_break = A % AA ? (A / AA) * AA : (A / AA - 1) * AA; 201 // Blocks of 4 columns 202 for (CeedInt j = (J / JJ) * JJ; j < J; j += 4) { 203 // Blocks of 4 rows 204 for (CeedInt a = 0; a < A_break; a += AA) { 205 rtype vv[AA]; // Output tile to be held in registers 206 for (CeedInt aa = 0; aa < AA; aa++) vv[aa] = loadu(&v[(a + aa) * J + j]); 207 208 for (CeedInt b = 0; b < B; b++) { 209 rtype tqv; 210 if (J - j == 1) { 211 tqv = set(0.0, 0.0, 0.0, t[(j + 0) * t_stride_0 + b * t_stride_1]); 212 } else if (J - j == 2) { 213 tqv = set(0.0, 0.0, t[(j + 1) * t_stride_0 + b * t_stride_1], t[(j + 0) * t_stride_0 + b * t_stride_1]); 214 } else if (J - 3 == j) { 215 tqv = 216 set(0.0, t[(j + 2) * t_stride_0 + b * t_stride_1], t[(j + 1) * t_stride_0 + b * t_stride_1], t[(j + 0) * t_stride_0 + b * t_stride_1]); 217 } else { 218 tqv = set(t[(j + 3) * t_stride_0 + b * t_stride_1], t[(j + 2) * t_stride_0 + b * t_stride_1], t[(j + 1) * t_stride_0 + b * t_stride_1], 219 t[(j + 0) * t_stride_0 + b * t_stride_1]); 220 } 221 for (CeedInt aa = 0; aa < AA; aa++) { // unroll 222 fmadd(vv[aa], tqv, set1(u[(a + aa) * B + b])); 223 } 224 } 225 for (CeedInt aa = 0; aa < AA; aa++) storeu(&v[(a + aa) * J + j], vv[aa]); 226 } 227 } 228 // Remainder of rows, all columns 229 for (CeedInt b = 0; b < B; b++) { 230 for (CeedInt j = (J / JJ) * JJ; j < J; j++) { 231 CeedScalar tq = t[j * t_stride_0 + b * t_stride_1]; 232 for (CeedInt a = A_break; a < A; a++) v[a * J + j] += tq * u[a * B + b]; 233 } 234 } 235 return CEED_ERROR_SUCCESS; 236 } 237 238 //------------------------------------------------------------------------------ 239 // Tensor Contract - Common Sizes 240 //------------------------------------------------------------------------------ 241 static int CeedTensorContract_Avx_Blocked_4_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 242 CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 243 return CeedTensorContract_Avx_Blocked(contract, A, B, C, J, t, t_mode, add, u, v, 4, 8); 244 } 245 static int CeedTensorContract_Avx_Remainder_8_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 246 CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 247 return CeedTensorContract_Avx_Remainder(contract, A, B, C, J, t, t_mode, add, u, v, 8, 8); 248 } 249 static int CeedTensorContract_Avx_Single_4_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 250 CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 251 return CeedTensorContract_Avx_Single(contract, A, B, C, J, t, t_mode, add, u, v, 4, 8); 252 } 253 254 //------------------------------------------------------------------------------ 255 // Tensor Contract Apply 256 //------------------------------------------------------------------------------ 257 static int CeedTensorContractApply_Avx(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 258 CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 259 const CeedInt blk_size = 8; 260 261 if (!add) { 262 for (CeedInt q = 0; q < A * J * C; q++) v[q] = (CeedScalar)0.0; 263 } 264 265 if (C == 1) { 266 // Serial C=1 Case 267 CeedTensorContract_Avx_Single_4_8(contract, A, B, C, J, t, t_mode, true, u, v); 268 } else { 269 // Blocks of 8 columns 270 if (C >= blk_size) CeedTensorContract_Avx_Blocked_4_8(contract, A, B, C, J, t, t_mode, true, u, v); 271 // Remainder of columns 272 if (C % blk_size) CeedTensorContract_Avx_Remainder_8_8(contract, A, B, C, J, t, t_mode, true, u, v); 273 } 274 275 return CEED_ERROR_SUCCESS; 276 } 277 278 //------------------------------------------------------------------------------ 279 // Tensor Contract Create 280 //------------------------------------------------------------------------------ 281 int CeedTensorContractCreate_Avx(CeedBasis basis, CeedTensorContract contract) { 282 Ceed ceed; 283 CeedCallBackend(CeedTensorContractGetCeed(contract, &ceed)); 284 285 CeedCallBackend(CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply", CeedTensorContractApply_Avx)); 286 287 return CEED_ERROR_SUCCESS; 288 } 289 290 //------------------------------------------------------------------------------ 291