// Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
// Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
// All Rights reserved. See files LICENSE and NOTICE for details.
//
// This file is part of CEED, a collection of benchmarks, miniapps, software
// libraries and APIs for efficient high-order finite element and spectral
// element discretizations for exascale applications. For more information and
// source code availability see http://github.com/ceed.
//
// The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
// a collaborative effort of two U.S. Department of Energy organizations (Office
// of Science and the National Nuclear Security Administration) responsible for
// the planning and preparation of a capable exascale ecosystem, including
// software, applications, hardware, advanced system engineering and early
// testbed platforms, in support of the nation's exascale computing imperative.
16 17 #include <ceed.h> 18 #include <ceed-backend.h> 19 #include <immintrin.h> 20 #include <stdbool.h> 21 #include "ceed-avx.h" 22 23 // c += a * b 24 #ifdef __FMA__ 25 # define fmadd(c,a,b) (c) = _mm256_fmadd_pd((a), (b), (c)) 26 #else 27 # define fmadd(c,a,b) (c) += _mm256_mul_pd((a), (b)) 28 #endif 29 30 //------------------------------------------------------------------------------ 31 // Blocked Tensor Contract 32 //------------------------------------------------------------------------------ 33 static inline int CeedTensorContract_Avx_Blocked(CeedTensorContract contract, 34 CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 35 CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u, 36 CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) { 37 CeedInt tstride0 = B, tstride1 = 1; 38 if (tmode == CEED_TRANSPOSE) { 39 tstride0 = 1; tstride1 = J; 40 } 41 42 for (CeedInt a=0; a<A; a++) { 43 // Blocks of 4 rows 44 for (CeedInt j=0; j<(J/JJ)*JJ; j+=JJ) { 45 for (CeedInt c=0; c<(C/CC)*CC; c+=CC) { 46 __m256d vv[JJ][CC/4]; // Output tile to be held in registers 47 for (CeedInt jj=0; jj<JJ; jj++) 48 for (CeedInt cc=0; cc<CC/4; cc++) 49 vv[jj][cc] = _mm256_loadu_pd(&v[(a*J+j+jj)*C+c+cc*4]); 50 51 for (CeedInt b=0; b<B; b++) { 52 for (CeedInt jj=0; jj<JJ; jj++) { // unroll 53 __m256d tqv = _mm256_set1_pd(t[(j+jj)*tstride0 + b*tstride1]); 54 for (CeedInt cc=0; cc<CC/4; cc++) // unroll 55 fmadd(vv[jj][cc], tqv, _mm256_loadu_pd(&u[(a*B+b)*C+c+cc*4])); 56 } 57 } 58 for (CeedInt jj=0; jj<JJ; jj++) 59 for (CeedInt cc=0; cc<CC/4; cc++) 60 _mm256_storeu_pd(&v[(a*J+j+jj)*C+c+cc*4], vv[jj][cc]); 61 } 62 } 63 // Remainder of rows 64 CeedInt j=(J/JJ)*JJ; 65 if (j < J) { 66 for (CeedInt c=0; c<(C/CC)*CC; c+=CC) { 67 __m256d vv[JJ][CC/4]; // Output tile to be held in registers 68 for (CeedInt jj=0; jj<J-j; jj++) 69 for (CeedInt cc=0; cc<CC/4; cc++) 70 vv[jj][cc] = _mm256_loadu_pd(&v[(a*J+j+jj)*C+c+cc*4]); 71 72 for (CeedInt b=0; b<B; b++) 
{ 73 for (CeedInt jj=0; jj<J-j; jj++) { // doesn't unroll 74 __m256d tqv = _mm256_set1_pd(t[(j+jj)*tstride0 + b*tstride1]); 75 for (CeedInt cc=0; cc<CC/4; cc++) // unroll 76 fmadd(vv[jj][cc], tqv, _mm256_loadu_pd(&u[(a*B+b)*C+c+cc*4])); 77 } 78 } 79 for (CeedInt jj=0; jj<J-j; jj++) 80 for (CeedInt cc=0; cc<CC/4; cc++) 81 _mm256_storeu_pd(&v[(a*J+j+jj)*C+c+cc*4], vv[jj][cc]); 82 } 83 } 84 } 85 return CEED_ERROR_SUCCESS; 86 } 87 88 //------------------------------------------------------------------------------ 89 // Serial Tensor Contract Remainder 90 //------------------------------------------------------------------------------ 91 static inline int CeedTensorContract_Avx_Remainder(CeedTensorContract contract, 92 CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 93 CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u, 94 CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) { 95 CeedInt tstride0 = B, tstride1 = 1; 96 if (tmode == CEED_TRANSPOSE) { 97 tstride0 = 1; tstride1 = J; 98 } 99 100 CeedInt Jbreak = J%JJ ? 
(J/JJ)*JJ : (J/JJ-1)*JJ; 101 for (CeedInt a=0; a<A; a++) { 102 // Blocks of 4 columns 103 for (CeedInt c = (C/CC)*CC; c<C; c+=4) { 104 // Blocks of 4 rows 105 for (CeedInt j=0; j<Jbreak; j+=JJ) { 106 __m256d vv[JJ]; // Output tile to be held in registers 107 for (CeedInt jj=0; jj<JJ; jj++) 108 vv[jj] = _mm256_loadu_pd(&v[(a*J+j+jj)*C+c]); 109 110 for (CeedInt b=0; b<B; b++) { 111 __m256d tqu; 112 if (C-c == 1) 113 tqu = _mm256_set_pd(0.0, 0.0, 0.0, u[(a*B+b)*C+c+0]); 114 else if (C-c == 2) 115 tqu = _mm256_set_pd(0.0, 0.0, u[(a*B+b)*C+c+1], 116 u[(a*B+b)*C+c+0]); 117 else if (C-c == 3) 118 tqu = _mm256_set_pd(0.0, u[(a*B+b)*C+c+2], u[(a*B+b)*C+c+1], 119 u[(a*B+b)*C+c+0]); 120 else 121 tqu = _mm256_loadu_pd(&u[(a*B+b)*C+c]); 122 for (CeedInt jj=0; jj<JJ; jj++) // unroll 123 fmadd(vv[jj], tqu, _mm256_set1_pd(t[(j+jj)*tstride0 + b*tstride1])); 124 } 125 for (CeedInt jj=0; jj<JJ; jj++) 126 _mm256_storeu_pd(&v[(a*J+j+jj)*C+c], vv[jj]); 127 } 128 } 129 // Remainder of rows, all columns 130 for (CeedInt j=Jbreak; j<J; j++) 131 for (CeedInt b=0; b<B; b++) { 132 CeedScalar tq = t[j*tstride0 + b*tstride1]; 133 for (CeedInt c=(C/CC)*CC; c<C; c++) 134 v[(a*J+j)*C+c] += tq * u[(a*B+b)*C+c]; 135 } 136 } 137 return CEED_ERROR_SUCCESS; 138 } 139 140 //------------------------------------------------------------------------------ 141 // Serial Tensor Contract C=1 142 //------------------------------------------------------------------------------ 143 static inline int CeedTensorContract_Avx_Single(CeedTensorContract contract, 144 CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 145 CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u, 146 CeedScalar *restrict v, const CeedInt AA, const CeedInt JJ) { 147 CeedInt tstride0 = B, tstride1 = 1; 148 if (tmode == CEED_TRANSPOSE) { 149 tstride0 = 1; tstride1 = J; 150 } 151 152 // Blocks of 4 rows 153 for (CeedInt a=0; a<(A/AA)*AA; a+=AA) { 154 for (CeedInt j=0; j<(J/JJ)*JJ; j+=JJ) { 155 __m256d 
vv[AA][JJ/4]; // Output tile to be held in registers 156 for (CeedInt aa=0; aa<AA; aa++) 157 for (CeedInt jj=0; jj<JJ/4; jj++) 158 vv[aa][jj] = _mm256_loadu_pd(&v[(a+aa)*J+j+jj*4]); 159 160 for (CeedInt b=0; b<B; b++) { 161 for (CeedInt jj=0; jj<JJ/4; jj++) { // unroll 162 __m256d tqv = _mm256_set_pd(t[(j+jj*4+3)*tstride0 + b*tstride1], 163 t[(j+jj*4+2)*tstride0 + b*tstride1], 164 t[(j+jj*4+1)*tstride0 + b*tstride1], 165 t[(j+jj*4+0)*tstride0 + b*tstride1]); 166 for (CeedInt aa=0; aa<AA; aa++) // unroll 167 fmadd(vv[aa][jj], tqv, _mm256_set1_pd(u[(a+aa)*B+b])); 168 } 169 } 170 for (CeedInt aa=0; aa<AA; aa++) 171 for (CeedInt jj=0; jj<JJ/4; jj++) 172 _mm256_storeu_pd(&v[(a+aa)*J+j+jj*4], vv[aa][jj]); 173 } 174 } 175 // Remainder of rows 176 CeedInt a=(A/AA)*AA; 177 for (CeedInt j=0; j<(J/JJ)*JJ; j+=JJ) { 178 __m256d vv[AA][JJ/4]; // Output tile to be held in registers 179 for (CeedInt aa=0; aa<A-a; aa++) 180 for (CeedInt jj=0; jj<JJ/4; jj++) 181 vv[aa][jj] = _mm256_loadu_pd(&v[(a+aa)*J+j+jj*4]); 182 183 for (CeedInt b=0; b<B; b++) { 184 for (CeedInt jj=0; jj<JJ/4; jj++) { // unroll 185 __m256d tqv = _mm256_set_pd(t[(j+jj*4+3)*tstride0 + b*tstride1], 186 t[(j+jj*4+2)*tstride0 + b*tstride1], 187 t[(j+jj*4+1)*tstride0 + b*tstride1], 188 t[(j+jj*4+0)*tstride0 + b*tstride1]); 189 for (CeedInt aa=0; aa<A-a; aa++) // unroll 190 fmadd(vv[aa][jj], tqv, _mm256_set1_pd(u[(a+aa)*B+b])); 191 } 192 } 193 for (CeedInt aa=0; aa<A-a; aa++) 194 for (CeedInt jj=0; jj<JJ/4; jj++) 195 _mm256_storeu_pd(&v[(a+aa)*J+j+jj*4], vv[aa][jj]); 196 } 197 // Column remainder 198 CeedInt Abreak = A%AA ? 
(A/AA)*AA : (A/AA-1)*AA; 199 // Blocks of 4 columns 200 for (CeedInt j = (J/JJ)*JJ; j<J; j+=4) { 201 // Blocks of 4 rows 202 for (CeedInt a=0; a<Abreak; a+=AA) { 203 __m256d vv[AA]; // Output tile to be held in registers 204 for (CeedInt aa=0; aa<AA; aa++) 205 vv[aa] = _mm256_loadu_pd(&v[(a+aa)*J+j]); 206 207 for (CeedInt b=0; b<B; b++) { 208 __m256d tqv; 209 if (J-j == 1) 210 tqv = _mm256_set_pd(0.0, 0.0, 0.0, t[(j+0)*tstride0 + b*tstride1]); 211 else if (J-j == 2) 212 tqv = _mm256_set_pd(0.0, 0.0, t[(j+1)*tstride0 + b*tstride1], 213 t[(j+0)*tstride0 + b*tstride1]); 214 else if (J-3 == j) 215 tqv = _mm256_set_pd(0.0, t[(j+2)*tstride0 + b*tstride1], 216 t[(j+1)*tstride0 + b*tstride1], 217 t[(j+0)*tstride0 + b*tstride1]); 218 else 219 tqv = _mm256_set_pd(t[(j+3)*tstride0 + b*tstride1], 220 t[(j+2)*tstride0 + b*tstride1], 221 t[(j+1)*tstride0 + b*tstride1], 222 t[(j+0)*tstride0 + b*tstride1]); 223 for (CeedInt aa=0; aa<AA; aa++) // unroll 224 fmadd(vv[aa], tqv, _mm256_set1_pd(u[(a+aa)*B+b])); 225 } 226 for (CeedInt aa=0; aa<AA; aa++) 227 _mm256_storeu_pd(&v[(a+aa)*J+j], vv[aa]); 228 } 229 } 230 // Remainder of rows, all columns 231 for (CeedInt b=0; b<B; b++) { 232 for (CeedInt j=(J/JJ)*JJ; j<J; j++) { 233 CeedScalar tq = t[j*tstride0 + b*tstride1]; 234 for (CeedInt a=Abreak; a<A; a++) 235 v[a*J+j] += tq * u[a*B+b]; 236 } 237 } 238 return CEED_ERROR_SUCCESS; 239 } 240 241 //------------------------------------------------------------------------------ 242 // Tensor Contract - Common Sizes 243 //------------------------------------------------------------------------------ 244 static int CeedTensorContract_Avx_Blocked_4_8(CeedTensorContract contract, 245 CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 246 CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u, 247 CeedScalar *restrict v) { 248 return CeedTensorContract_Avx_Blocked(contract, A, B, C, J, t, tmode, Add, u, 249 v, 4, 8); 250 } 251 static int 
CeedTensorContract_Avx_Remainder_8_8(CeedTensorContract contract, 252 CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 253 CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u, 254 CeedScalar *restrict v) { 255 return CeedTensorContract_Avx_Remainder(contract, A, B, C, J, t, tmode, Add, 256 u, v, 8, 8); 257 } 258 static int CeedTensorContract_Avx_Single_4_8(CeedTensorContract contract, 259 CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 260 CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u, 261 CeedScalar *restrict v) { 262 return CeedTensorContract_Avx_Single(contract, A, B, C, J, t, tmode, Add, u, 263 v, 4, 8); 264 } 265 266 //------------------------------------------------------------------------------ 267 // Tensor Contract Apply 268 //------------------------------------------------------------------------------ 269 static int CeedTensorContractApply_Avx(CeedTensorContract contract, CeedInt A, 270 CeedInt B, CeedInt C, CeedInt J, 271 const CeedScalar *restrict t, 272 CeedTransposeMode tmode, 273 const CeedInt Add, 274 const CeedScalar *restrict u, 275 CeedScalar *restrict v) { 276 const CeedInt blksize = 8; 277 278 if (!Add) 279 for (CeedInt q=0; q<A*J*C; q++) 280 v[q] = (CeedScalar) 0.0; 281 282 if (C == 1) { 283 // Serial C=1 Case 284 CeedTensorContract_Avx_Single_4_8(contract, A, B, C, J, t, tmode, true, u, 285 v); 286 } else { 287 // Blocks of 8 columns 288 if (C >= blksize) 289 CeedTensorContract_Avx_Blocked_4_8(contract, A, B, C, J, t, tmode, true, 290 u, v); 291 // Remainder of columns 292 if (C % blksize) 293 CeedTensorContract_Avx_Remainder_8_8(contract, A, B, C, J, t, tmode, true, 294 u, v); 295 } 296 297 return CEED_ERROR_SUCCESS; 298 } 299 300 //------------------------------------------------------------------------------ 301 // Tensor Contract Destroy 302 //------------------------------------------------------------------------------ 303 static int 
CeedTensorContractDestroy_Avx(CeedTensorContract contract) { 304 return CEED_ERROR_SUCCESS; 305 } 306 307 //------------------------------------------------------------------------------ 308 // Tensor Contract Create 309 //------------------------------------------------------------------------------ 310 int CeedTensorContractCreate_Avx(CeedBasis basis, CeedTensorContract contract) { 311 int ierr; 312 Ceed ceed; 313 ierr = CeedTensorContractGetCeed(contract, &ceed); CeedChkBackend(ierr); 314 315 ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply", 316 CeedTensorContractApply_Avx); CeedChkBackend(ierr); 317 ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Destroy", 318 CeedTensorContractDestroy_Avx); CeedChkBackend(ierr); 319 320 return CEED_ERROR_SUCCESS; 321 } 322 //------------------------------------------------------------------------------ 323