1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC. 2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707. 3 // All Rights reserved. See files LICENSE and NOTICE for details. 4 // 5 // This file is part of CEED, a collection of benchmarks, miniapps, software 6 // libraries and APIs for efficient high-order finite element and spectral 7 // element discretizations for exascale applications. For more information and 8 // source code availability see http://github.com/ceed. 9 // 10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC, 11 // a collaborative effort of two U.S. Department of Energy organizations (Office 12 // of Science and the National Nuclear Security Administration) responsible for 13 // the planning and preparation of a capable exascale ecosystem, including 14 // software, applications, hardware, advanced system engineering and early 15 // testbed platforms, in support of the nation's exascale computing imperative. 
#include "ceed-avx.h"

// c += a * b
// Uses a true fused multiply-add when the compiler advertises FMA support;
// the fallback relies on the GCC/Clang vector-extension `+=` on __m256d.
#ifdef __FMA__
# define fmadd(c,a,b) (c) = _mm256_fmadd_pd((a), (b), (c))
#else
# define fmadd(c,a,b) (c) += _mm256_mul_pd((a), (b))
#endif

//------------------------------------------------------------------------------
// Blocked Tensor Contract
//
// Accumulates v[a,j,c] += sum_b t[b,j] * u[a,b,c] (t indexed [j,b] when
// tmode == CEED_TRANSPOSE) for the full column blocks c < (C/CC)*CC only;
// trailing columns are handled by CeedTensorContract_Avx_Remainder.
// JJ is the row-block size and CC the column-block size (CC a multiple of 4,
// so each tile row is CC/4 __m256d registers).
// Note: `contract` and `Add` are unused here — the tile is always loaded from
// v first, so accumulation is inherent; the caller zeroes v when !Add.
//------------------------------------------------------------------------------
static inline int CeedTensorContract_Avx_Blocked(CeedTensorContract contract,
    CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
    CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
    CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) {
  // Strides into t: default contracts v_j = sum_b t[b*B-stride + j] u_b
  CeedInt tstride0 = B, tstride1 = 1;
  if (tmode == CEED_TRANSPOSE) {
    tstride0 = 1; tstride1 = J;
  }

  for (CeedInt a=0; a<A; a++) {
    // Blocks of 4 rows
    for (CeedInt j=0; j<(J/JJ)*JJ; j+=JJ) {
      for (CeedInt c=0; c<(C/CC)*CC; c+=CC) {
        __m256d vv[JJ][CC/4]; // Output tile to be held in registers
        for (CeedInt jj=0; jj<JJ; jj++)
          for (CeedInt cc=0; cc<CC/4; cc++)
            vv[jj][cc] = _mm256_loadu_pd(&v[(a*J+j+jj)*C+c+cc*4]);

        // Contract over b: each t entry is broadcast and multiplied against
        // a 4-wide slice of u, accumulating into the register tile
        for (CeedInt b=0; b<B; b++) {
          for (CeedInt jj=0; jj<JJ; jj++) { // unroll
            __m256d tqv = _mm256_set1_pd(t[(j+jj)*tstride0 + b*tstride1]);
            for (CeedInt cc=0; cc<CC/4; cc++) // unroll
              fmadd(vv[jj][cc], tqv, _mm256_loadu_pd(&u[(a*B+b)*C+c+cc*4]));
          }
        }
        // Write the finished tile back to v
        for (CeedInt jj=0; jj<JJ; jj++)
          for (CeedInt cc=0; cc<CC/4; cc++)
            _mm256_storeu_pd(&v[(a*J+j+jj)*C+c+cc*4], vv[jj][cc]);
      }
    }
    // Remainder of rows: same tiling but with a runtime row count (J-j < JJ),
    // so the jj loops do not unroll
    CeedInt j=(J/JJ)*JJ;
    if (j < J) {
      for (CeedInt c=0; c<(C/CC)*CC; c+=CC) {
        __m256d vv[JJ][CC/4]; // Output tile to be held in registers
        for (CeedInt jj=0; jj<J-j; jj++)
          for (CeedInt cc=0; cc<CC/4; cc++)
            vv[jj][cc] = _mm256_loadu_pd(&v[(a*J+j+jj)*C+c+cc*4]);

        for (CeedInt b=0; b<B; b++) {
          for (CeedInt jj=0; jj<J-j; jj++) { // doesn't unroll
            __m256d tqv =
              _mm256_set1_pd(t[(j+jj)*tstride0 + b*tstride1]);
            for (CeedInt cc=0; cc<CC/4; cc++) // unroll
              fmadd(vv[jj][cc], tqv, _mm256_loadu_pd(&u[(a*B+b)*C+c+cc*4]));
          }
        }
        for (CeedInt jj=0; jj<J-j; jj++)
          for (CeedInt cc=0; cc<CC/4; cc++)
            _mm256_storeu_pd(&v[(a*J+j+jj)*C+c+cc*4], vv[jj][cc]);
      }
    }
  }
  return 0;
}

//------------------------------------------------------------------------------
// Serial Tensor Contract Remainder
//
// Handles the trailing C%CC columns that the blocked kernel above skipped,
// using 4-wide vectors with zero-padded loads of u for a final partial group
// of columns. Jbreak reserves at least one row block's worth of trailing rows
// for the scalar loop at the bottom, so the 4-wide stores (which may overrun
// the end of a row into the next row of v) never write past the end of v.
// `contract` and `Add` are unused, as in the blocked kernel.
//------------------------------------------------------------------------------
static inline int CeedTensorContract_Avx_Remainder(CeedTensorContract contract,
    CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
    CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
    CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) {
  CeedInt tstride0 = B, tstride1 = 1;
  if (tmode == CEED_TRANSPOSE) {
    tstride0 = 1; tstride1 = J;
  }

  // Last row index handled with vectors; remaining rows are done scalar
  CeedInt Jbreak = J%JJ ? (J/JJ)*JJ : (J/JJ-1)*JJ;
  for (CeedInt a=0; a<A; a++) {
    // Blocks of 4 columns
    for (CeedInt c = (C/CC)*CC; c<C; c+=4) {
      // Blocks of 4 rows
      for (CeedInt j=0; j<Jbreak; j+=JJ) {
        __m256d vv[JJ]; // Output tile to be held in registers
        for (CeedInt jj=0; jj<JJ; jj++)
          vv[jj] = _mm256_loadu_pd(&v[(a*J+j+jj)*C+c]);

        for (CeedInt b=0; b<B; b++) {
          // Build a 4-wide slice of u, zero-padded when fewer than 4
          // columns remain (the zeros contribute nothing to the fmadd)
          __m256d tqu;
          if (C-c == 1)
            tqu = _mm256_set_pd(0.0, 0.0, 0.0, u[(a*B+b)*C+c+0]);
          else if (C-c == 2)
            tqu = _mm256_set_pd(0.0, 0.0, u[(a*B+b)*C+c+1],
                                u[(a*B+b)*C+c+0]);
          else if (C-c == 3)
            tqu = _mm256_set_pd(0.0, u[(a*B+b)*C+c+2], u[(a*B+b)*C+c+1],
                                u[(a*B+b)*C+c+0]);
          else
            tqu = _mm256_loadu_pd(&u[(a*B+b)*C+c]);
          for (CeedInt jj=0; jj<JJ; jj++) // unroll
            fmadd(vv[jj], tqu, _mm256_set1_pd(t[(j+jj)*tstride0 + b*tstride1]));
        }
        for (CeedInt jj=0; jj<JJ; jj++)
          _mm256_storeu_pd(&v[(a*J+j+jj)*C+c], vv[jj]);
      }
    }
    // Remainder of rows, all (trailing) columns — plain scalar accumulation
    for (CeedInt j=Jbreak; j<J; j++)
      for (CeedInt b=0; b<B; b++) {
        CeedScalar tq = t[j*tstride0 + b*tstride1];
        for (CeedInt c=(C/CC)*CC; c<C; c++)
          v[(a*J+j)*C+c] += tq * u[(a*B+b)*C+c];
      }
  }
  return 0;
}

//------------------------------------------------------------------------------
// Serial Tensor Contract C=1
//
// Specialization for C == 1: v[a,j] += sum_b t[b,j] * u[a,b] (t indexed [j,b]
// when tmode == CEED_TRANSPOSE). Vectorizes over j (4 at a time, JJ per tile)
// with AA rows of a per tile; u entries are broadcast. Abreak reserves trailing
// rows for the final scalar loop so the 4-wide stores on partial j groups
// (which may overrun a row into the next) never pass the end of v.
// `contract`, `C`, and `Add` are unused here.
//------------------------------------------------------------------------------
static inline int CeedTensorContract_Avx_Single(CeedTensorContract contract,
    CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
    CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
    CeedScalar *restrict v, const CeedInt AA, const CeedInt JJ) {
  CeedInt tstride0 = B, tstride1 = 1;
  if (tmode == CEED_TRANSPOSE) {
    tstride0 = 1; tstride1 = J;
  }

  // Blocks of 4 rows
  for (CeedInt a=0; a<(A/AA)*AA; a+=AA) {
    for (CeedInt j=0; j<(J/JJ)*JJ; j+=JJ) {
      __m256d vv[AA][JJ/4]; // Output tile to be held in registers
      for (CeedInt aa=0; aa<AA; aa++)
        for (CeedInt jj=0; jj<JJ/4; jj++)
          vv[aa][jj] = _mm256_loadu_pd(&v[(a+aa)*J+j+jj*4]);

      for (CeedInt b=0; b<B; b++) {
        for (CeedInt jj=0; jj<JJ/4; jj++) { // unroll
          // Gather 4 consecutive-in-j entries of t into one vector
          __m256d tqv = _mm256_set_pd(t[(j+jj*4+3)*tstride0 + b*tstride1],
                                      t[(j+jj*4+2)*tstride0 + b*tstride1],
                                      t[(j+jj*4+1)*tstride0 + b*tstride1],
                                      t[(j+jj*4+0)*tstride0 + b*tstride1]);
          for (CeedInt aa=0; aa<AA; aa++) // unroll
            fmadd(vv[aa][jj], tqv, _mm256_set1_pd(u[(a+aa)*B+b]));
        }
      }
      for (CeedInt aa=0; aa<AA; aa++)
        for (CeedInt jj=0; jj<JJ/4; jj++)
          _mm256_storeu_pd(&v[(a+aa)*J+j+jj*4], vv[aa][jj]);
    }
  }
  // Remainder of rows: runtime row count A-a < AA (no iterations if A%AA == 0)
  CeedInt a=(A/AA)*AA;
  for (CeedInt j=0; j<(J/JJ)*JJ; j+=JJ) {
    __m256d vv[AA][JJ/4]; // Output tile to be held in registers
    for (CeedInt aa=0; aa<A-a; aa++)
      for (CeedInt jj=0; jj<JJ/4; jj++)
        vv[aa][jj] = _mm256_loadu_pd(&v[(a+aa)*J+j+jj*4]);

    for (CeedInt b=0; b<B; b++) {
      for (CeedInt jj=0; jj<JJ/4; jj++) { // unroll
        __m256d tqv = _mm256_set_pd(t[(j+jj*4+3)*tstride0 + b*tstride1],
                                    t[(j+jj*4+2)*tstride0 + b*tstride1],
                                    t[(j+jj*4+1)*tstride0 + b*tstride1],
                                    t[(j+jj*4+0)*tstride0 + b*tstride1]);
        for (CeedInt aa=0; aa<A-a; aa++) // unroll
          fmadd(vv[aa][jj], tqv, _mm256_set1_pd(u[(a+aa)*B+b]));
      }
    }
    for (CeedInt aa=0; aa<A-a; aa++)
      for (CeedInt jj=0; jj<JJ/4; jj++)
        _mm256_storeu_pd(&v[(a+aa)*J+j+jj*4], vv[aa][jj]);
  }
  // Column remainder
  // Last row index handled with vectors; remaining rows are done scalar below
  CeedInt Abreak = A%AA ? (A/AA)*AA : (A/AA-1)*AA;
  // Blocks of 4 columns
  for (CeedInt j = (J/JJ)*JJ; j<J; j+=4) {
    // Blocks of 4 rows
    for (CeedInt a=0; a<Abreak; a+=AA) {
      __m256d vv[AA]; // Output tile to be held in registers
      for (CeedInt aa=0; aa<AA; aa++)
        vv[aa] = _mm256_loadu_pd(&v[(a+aa)*J+j]);

      for (CeedInt b=0; b<B; b++) {
        // Build a 4-wide slice of t, zero-padded when fewer than 4
        // j entries remain (the zeros contribute nothing to the fmadd)
        __m256d tqv;
        if (J-j == 1)
          tqv = _mm256_set_pd(0.0, 0.0, 0.0, t[(j+0)*tstride0 + b*tstride1]);
        else if (J-j == 2)
          tqv = _mm256_set_pd(0.0, 0.0, t[(j+1)*tstride0 + b*tstride1],
                              t[(j+0)*tstride0 + b*tstride1]);
        else if (J-3 == j)
          tqv = _mm256_set_pd(0.0, t[(j+2)*tstride0 + b*tstride1],
                              t[(j+1)*tstride0 + b*tstride1],
                              t[(j+0)*tstride0 + b*tstride1]);
        else
          tqv = _mm256_set_pd(t[(j+3)*tstride0 + b*tstride1],
                              t[(j+2)*tstride0 + b*tstride1],
                              t[(j+1)*tstride0 + b*tstride1],
                              t[(j+0)*tstride0 + b*tstride1]);
        for (CeedInt aa=0; aa<AA; aa++) // unroll
          fmadd(vv[aa], tqv, _mm256_set1_pd(u[(a+aa)*B+b]));
      }
      for (CeedInt aa=0; aa<AA; aa++)
        _mm256_storeu_pd(&v[(a+aa)*J+j], vv[aa]);
    }
  }
  // Remainder of rows, all columns — scalar accumulation for the trailing
  // j values of rows a >= Abreak
  for (CeedInt b=0; b<B; b++) {
    for (CeedInt j=(J/JJ)*JJ; j<J; j++) {
      CeedScalar tq = t[j*tstride0 + b*tstride1];
      for (CeedInt a=Abreak; a<A; a++)
        v[a*J+j] += tq * u[a*B+b];
    }
  }
  return 0;
}

//------------------------------------------------------------------------------ 238 // Tensor Contract - Common Sizes 239 //------------------------------------------------------------------------------ 240 static int CeedTensorContract_Avx_Blocked_4_8(CeedTensorContract contract, 241 CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 242 CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u, 243 CeedScalar *restrict v) { 244 return CeedTensorContract_Avx_Blocked(contract, A, B, C, J, t, tmode, Add, u, 245 v, 4, 8); 246 } 247 static int CeedTensorContract_Avx_Remainder_8_8(CeedTensorContract contract, 248 CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 249 CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u, 250 CeedScalar *restrict v) { 251 return CeedTensorContract_Avx_Remainder(contract, A, B, C, J, t, tmode, Add, 252 u, v, 8, 8); 253 } 254 static int CeedTensorContract_Avx_Single_4_8(CeedTensorContract contract, 255 CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 256 CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u, 257 CeedScalar *restrict v) { 258 return CeedTensorContract_Avx_Single(contract, A, B, C, J, t, tmode, Add, u, 259 v, 4, 8); 260 } 261 262 //------------------------------------------------------------------------------ 263 // Tensor Contract Apply 264 //------------------------------------------------------------------------------ 265 static int CeedTensorContractApply_Avx(CeedTensorContract contract, CeedInt A, 266 CeedInt B, CeedInt C, CeedInt J, 267 const CeedScalar *restrict t, 268 CeedTransposeMode tmode, 269 const CeedInt Add, 270 const CeedScalar *restrict u, 271 CeedScalar *restrict v) { 272 const CeedInt blksize = 8; 273 274 if (!Add) 275 for (CeedInt q=0; q<A*J*C; q++) 276 v[q] = (CeedScalar) 0.0; 277 278 if (C == 1) { 279 // Serial C=1 Case 280 CeedTensorContract_Avx_Single_4_8(contract, A, B, C, J, 
t, tmode, true, u, 281 v); 282 } else { 283 // Blocks of 8 columns 284 if (C >= blksize) 285 CeedTensorContract_Avx_Blocked_4_8(contract, A, B, C, J, t, tmode, true, 286 u, v); 287 // Remainder of columns 288 if (C % blksize) 289 CeedTensorContract_Avx_Remainder_8_8(contract, A, B, C, J, t, tmode, true, 290 u, v); 291 } 292 293 return 0; 294 } 295 296 //------------------------------------------------------------------------------ 297 // Tensor Contract Destroy 298 //------------------------------------------------------------------------------ 299 static int CeedTensorContractDestroy_Avx(CeedTensorContract contract) { 300 return 0; 301 } 302 303 //------------------------------------------------------------------------------ 304 // Tensor Contract Create 305 //------------------------------------------------------------------------------ 306 int CeedTensorContractCreate_Avx(CeedBasis basis, CeedTensorContract contract) { 307 int ierr; 308 Ceed ceed; 309 ierr = CeedTensorContractGetCeed(contract, &ceed); CeedChk(ierr); 310 311 ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply", 312 CeedTensorContractApply_Avx); CeedChk(ierr); 313 ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Destroy", 314 CeedTensorContractDestroy_Avx); CeedChk(ierr); 315 316 return 0; 317 } 318 //------------------------------------------------------------------------------ 319