1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <immintrin.h> 11 #include <stdbool.h> 12 13 #ifdef CEED_F64_H 14 #define rtype __m256d 15 #define loadu _mm256_loadu_pd 16 #define storeu _mm256_storeu_pd 17 #define set _mm256_set_pd 18 #define set1 _mm256_set1_pd 19 // c += a * b 20 #ifdef __FMA__ 21 #define fmadd(c, a, b) (c) = _mm256_fmadd_pd((a), (b), (c)) 22 #else 23 #define fmadd(c, a, b) (c) += _mm256_mul_pd((a), (b)) 24 #endif 25 #else 26 #define rtype __m128 27 #define loadu _mm_loadu_ps 28 #define storeu _mm_storeu_ps 29 #define set _mm_set_ps 30 #define set1 _mm_set1_ps 31 // c += a * b 32 #ifdef __FMA__ 33 #define fmadd(c, a, b) (c) = _mm_fmadd_ps((a), (b), (c)) 34 #else 35 #define fmadd(c, a, b) (c) += _mm_mul_ps((a), (b)) 36 #endif 37 #endif 38 39 //------------------------------------------------------------------------------ 40 // Blocked Tensor Contract 41 //------------------------------------------------------------------------------ 42 static inline int CeedTensorContract_Avx_Blocked(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, 43 const CeedScalar *restrict t, CeedTransposeMode t_mode, const CeedInt add, 44 const CeedScalar *restrict u, CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) { 45 CeedInt t_stride_0 = B, t_stride_1 = 1; 46 47 if (t_mode == CEED_TRANSPOSE) { 48 t_stride_0 = 1; 49 t_stride_1 = J; 50 } 51 52 for (CeedInt a = 0; a < A; a++) { 53 // Blocks of 4 rows 54 for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) { 55 for (CeedInt c = 0; c < (C / CC) * CC; c += CC) { 56 rtype vv[JJ][CC / 4]; // Output tile to be held in registers 57 for (CeedInt jj = 0; jj < JJ; jj++) { 58 for (CeedInt cc = 0; cc < CC / 4; cc++) vv[jj][cc] = loadu(&v[(a * J + j + jj) * C + c + cc * 4]); 59 } 60 for (CeedInt b = 0; b < B; b++) { 61 for (CeedInt jj = 0; jj < JJ; jj++) { // unroll 62 rtype tqv = set1(t[(j + jj) * t_stride_0 + b * t_stride_1]); 63 for (CeedInt cc = 0; cc < CC / 4; cc++) { // unroll 64 fmadd(vv[jj][cc], tqv, loadu(&u[(a * B + b) * C + c + cc * 4])); 65 } 66 } 67 } 68 for (CeedInt jj = 0; jj < JJ; jj++) { 69 for (CeedInt cc = 0; cc < CC / 4; cc++) storeu(&v[(a * J + j + jj) * C + c + cc * 4], vv[jj][cc]); 70 } 71 } 72 } 73 // Remainder of rows 74 const CeedInt j = (J / JJ) * JJ; 75 76 if (j < J) { 77 for (CeedInt c = 0; c < (C / CC) * CC; c += CC) { 78 rtype vv[JJ][CC / 4]; // Output tile to be held in registers 79 80 for (CeedInt jj = 0; jj < J - j; jj++) { 81 for (CeedInt cc = 0; cc < CC / 4; cc++) vv[jj][cc] = loadu(&v[(a * J + j + jj) * C + c + cc * 4]); 82 } 83 for (CeedInt b = 0; b < B; b++) { 84 for (CeedInt jj = 0; jj < J - j; jj++) { // doesn't unroll 85 rtype tqv = set1(t[(j + jj) * t_stride_0 + b * t_stride_1]); 86 87 for (CeedInt cc = 0; cc < CC / 4; cc++) { // unroll 88 fmadd(vv[jj][cc], tqv, loadu(&u[(a * B + b) * C + c + cc * 4])); 89 } 90 } 91 } 92 for (CeedInt jj = 0; jj < J - j; jj++) { 93 for (CeedInt cc = 0; cc < CC / 4; cc++) storeu(&v[(a * J + j + jj) * C + c + cc * 4], vv[jj][cc]); 94 } 95 } 96 } 97 } 98 return CEED_ERROR_SUCCESS; 99 } 100 101 //------------------------------------------------------------------------------ 102 // Serial Tensor Contract Remainder 103 //------------------------------------------------------------------------------ 104 static inline int CeedTensorContract_Avx_Remainder(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, 105 const CeedScalar *restrict t, CeedTransposeMode t_mode, const CeedInt add, 106 const CeedScalar *restrict u, CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) { 107 CeedInt t_stride_0 = B, t_stride_1 = 1; 108 109 if (t_mode == CEED_TRANSPOSE) { 110 t_stride_0 = 1; 111 t_stride_1 = J; 112 } 113 114 const CeedInt J_break = J % JJ ? (J / JJ) * JJ : (J / JJ - 1) * JJ; 115 116 for (CeedInt a = 0; a < A; a++) { 117 // Blocks of 4 columns 118 for (CeedInt c = (C / CC) * CC; c < C; c += 4) { 119 // Blocks of 4 rows 120 for (CeedInt j = 0; j < J_break; j += JJ) { 121 rtype vv[JJ]; // Output tile to be held in registers 122 123 for (CeedInt jj = 0; jj < JJ; jj++) vv[jj] = loadu(&v[(a * J + j + jj) * C + c]); 124 for (CeedInt b = 0; b < B; b++) { 125 rtype tqu; 126 127 if (C - c == 1) tqu = set(0.0, 0.0, 0.0, u[(a * B + b) * C + c + 0]); 128 else if (C - c == 2) tqu = set(0.0, 0.0, u[(a * B + b) * C + c + 1], u[(a * B + b) * C + c + 0]); 129 else if (C - c == 3) tqu = set(0.0, u[(a * B + b) * C + c + 2], u[(a * B + b) * C + c + 1], u[(a * B + b) * C + c + 0]); 130 else tqu = loadu(&u[(a * B + b) * C + c]); 131 for (CeedInt jj = 0; jj < JJ; jj++) { // unroll 132 fmadd(vv[jj], tqu, set1(t[(j + jj) * t_stride_0 + b * t_stride_1])); 133 } 134 } 135 for (CeedInt jj = 0; jj < JJ; jj++) storeu(&v[(a * J + j + jj) * C + c], vv[jj]); 136 } 137 } 138 // Remainder of rows, all columns 139 for (CeedInt j = J_break; j < J; j++) { 140 for (CeedInt b = 0; b < B; b++) { 141 const CeedScalar tq = t[j * t_stride_0 + b * t_stride_1]; 142 143 for (CeedInt c = (C / CC) * CC; c < C; c++) v[(a * J + j) * C + c] += tq * u[(a * B + b) * C + c]; 144 } 145 } 146 } 147 return CEED_ERROR_SUCCESS; 148 } 149 150 //------------------------------------------------------------------------------ 151 // Serial Tensor Contract C=1 152 //------------------------------------------------------------------------------ 153 static inline int CeedTensorContract_Avx_Single(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 154 CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v, 155 const CeedInt AA, const CeedInt JJ) { 156 CeedInt t_stride_0 = B, t_stride_1 = 1; 157 158 if (t_mode == CEED_TRANSPOSE) { 159 t_stride_0 = 1; 160 t_stride_1 = J; 161 } 162 163 // Blocks of 4 rows 164 for (CeedInt a = 0; a < (A / AA) * AA; a += AA) { 165 for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) { 166 rtype vv[AA][JJ / 4]; // Output tile to be held in registers 167 168 for (CeedInt aa = 0; aa < AA; aa++) { 169 for (CeedInt jj = 0; jj < JJ / 4; jj++) vv[aa][jj] = loadu(&v[(a + aa) * J + j + jj * 4]); 170 } 171 for (CeedInt b = 0; b < B; b++) { 172 for (CeedInt jj = 0; jj < JJ / 4; jj++) { // unroll 173 rtype tqv = set(t[(j + jj * 4 + 3) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 2) * t_stride_0 + b * t_stride_1], 174 t[(j + jj * 4 + 1) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 0) * t_stride_0 + b * t_stride_1]); 175 176 for (CeedInt aa = 0; aa < AA; aa++) { // unroll 177 fmadd(vv[aa][jj], tqv, set1(u[(a + aa) * B + b])); 178 } 179 } 180 } 181 for (CeedInt aa = 0; aa < AA; aa++) { 182 for (CeedInt jj = 0; jj < JJ / 4; jj++) storeu(&v[(a + aa) * J + j + jj * 4], vv[aa][jj]); 183 } 184 } 185 } 186 // Remainder of rows 187 const CeedInt a = (A / AA) * AA; 188 189 for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) { 190 rtype vv[AA][JJ / 4]; // Output tile to be held in registers 191 192 for (CeedInt aa = 0; aa < A - a; aa++) { 193 for (CeedInt jj = 0; jj < JJ / 4; jj++) vv[aa][jj] = loadu(&v[(a + aa) * J + j + jj * 4]); 194 } 195 for (CeedInt b = 0; b < B; b++) { 196 for (CeedInt jj = 0; jj < JJ / 4; jj++) { // unroll 197 rtype tqv = set(t[(j + jj * 4 + 3) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 2) * t_stride_0 + b * t_stride_1], 198 t[(j + jj * 4 + 1) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 0) * t_stride_0 + b * t_stride_1]); 199 200 for (CeedInt aa = 0; aa < A - a; aa++) { // unroll 201 fmadd(vv[aa][jj], tqv, set1(u[(a + aa) * B + b])); 202 } 203 } 204 } 205 for (CeedInt aa = 0; aa < A - a; aa++) { 206 for (CeedInt jj = 0; jj < JJ / 4; jj++) storeu(&v[(a + aa) * J + j + jj * 4], vv[aa][jj]); 207 } 208 } 209 // Column remainder 210 const CeedInt A_break = A % AA ? (A / AA) * AA : (A / AA - 1) * AA; 211 212 // Blocks of 4 columns 213 for (CeedInt j = (J / JJ) * JJ; j < J; j += 4) { 214 // Blocks of 4 rows 215 for (CeedInt a = 0; a < A_break; a += AA) { 216 rtype vv[AA]; // Output tile to be held in registers 217 218 for (CeedInt aa = 0; aa < AA; aa++) vv[aa] = loadu(&v[(a + aa) * J + j]); 219 for (CeedInt b = 0; b < B; b++) { 220 rtype tqv; 221 222 if (J - j == 1) { 223 tqv = set(0.0, 0.0, 0.0, t[(j + 0) * t_stride_0 + b * t_stride_1]); 224 } else if (J - j == 2) { 225 tqv = set(0.0, 0.0, t[(j + 1) * t_stride_0 + b * t_stride_1], t[(j + 0) * t_stride_0 + b * t_stride_1]); 226 } else if (J - 3 == j) { 227 tqv = 228 set(0.0, t[(j + 2) * t_stride_0 + b * t_stride_1], t[(j + 1) * t_stride_0 + b * t_stride_1], t[(j + 0) * t_stride_0 + b * t_stride_1]); 229 } else { 230 tqv = set(t[(j + 3) * t_stride_0 + b * t_stride_1], t[(j + 2) * t_stride_0 + b * t_stride_1], t[(j + 1) * t_stride_0 + b * t_stride_1], 231 t[(j + 0) * t_stride_0 + b * t_stride_1]); 232 } 233 for (CeedInt aa = 0; aa < AA; aa++) { // unroll 234 fmadd(vv[aa], tqv, set1(u[(a + aa) * B + b])); 235 } 236 } 237 for (CeedInt aa = 0; aa < AA; aa++) storeu(&v[(a + aa) * J + j], vv[aa]); 238 } 239 } 240 // Remainder of rows, all columns 241 for (CeedInt b = 0; b < B; b++) { 242 for (CeedInt j = (J / JJ) * JJ; j < J; j++) { 243 const CeedScalar tq = t[j * t_stride_0 + b * t_stride_1]; 244 245 for (CeedInt a = A_break; a < A; a++) v[a * J + j] += tq * u[a * B + b]; 246 } 247 } 248 return CEED_ERROR_SUCCESS; 249 } 250 251 //------------------------------------------------------------------------------ 252 // Tensor Contract - Common Sizes 253 //------------------------------------------------------------------------------ 254 static int CeedTensorContract_Avx_Blocked_4_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 255 CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 256 return CeedTensorContract_Avx_Blocked(contract, A, B, C, J, t, t_mode, add, u, v, 4, 8); 257 } 258 static int CeedTensorContract_Avx_Remainder_8_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 259 CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 260 return CeedTensorContract_Avx_Remainder(contract, A, B, C, J, t, t_mode, add, u, v, 8, 8); 261 } 262 static int CeedTensorContract_Avx_Single_4_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 263 CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 264 return CeedTensorContract_Avx_Single(contract, A, B, C, J, t, t_mode, add, u, v, 4, 8); 265 } 266 267 //------------------------------------------------------------------------------ 268 // Tensor Contract Apply 269 //------------------------------------------------------------------------------ 270 static int CeedTensorContractApply_Avx(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 271 CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 272 const CeedInt blk_size = 8; 273 274 if (!add) { 275 for (CeedInt q = 0; q < A * J * C; q++) v[q] = (CeedScalar)0.0; 276 } 277 278 if (C == 1) { 279 // Serial C=1 Case 280 CeedTensorContract_Avx_Single_4_8(contract, A, B, C, J, t, t_mode, true, u, v); 281 } else { 282 // Blocks of 8 columns 283 if (C >= blk_size) CeedTensorContract_Avx_Blocked_4_8(contract, A, B, C, J, t, t_mode, true, u, v); 284 // Remainder of columns 285 if (C % blk_size) CeedTensorContract_Avx_Remainder_8_8(contract, A, B, C, J, t, t_mode, true, u, v); 286 } 287 return CEED_ERROR_SUCCESS; 288 } 289 290 //------------------------------------------------------------------------------ 291 // Tensor Contract Create 292 //------------------------------------------------------------------------------ 293 int CeedTensorContractCreate_Avx(CeedTensorContract contract) { 294 CeedCallBackend(CeedSetBackendFunction(CeedTensorContractReturnCeed(contract), "TensorContract", contract, "Apply", CeedTensorContractApply_Avx)); 295 return CEED_ERROR_SUCCESS; 296 } 297 298 //------------------------------------------------------------------------------ 299