// Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
// Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
// All Rights reserved. See files LICENSE and NOTICE for details.
//
// This file is part of CEED, a collection of benchmarks, miniapps, software
// libraries and APIs for efficient high-order finite element and spectral
// element discretizations for exascale applications. For more information and
// source code availability see http://github.com/ceed.
//
// The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
// a collaborative effort of two U.S. Department of Energy organizations (Office
// of Science and the National Nuclear Security Administration) responsible for
// the planning and preparation of a capable exascale ecosystem, including
// software, applications, hardware, advanced system engineering and early
// testbed platforms, in support of the nation's exascale computing imperative.
16 17 #include "ceed-avx.h" 18 19 // Blocked Tensor Contact 20 static inline int CeedTensorContract_Avx_Blocked(CeedTensorContract contract, 21 CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 22 CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u, 23 CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) { 24 CeedInt tstride0 = B, tstride1 = 1; 25 if (tmode == CEED_TRANSPOSE) { 26 tstride0 = 1; tstride1 = J; 27 } 28 29 if (!Add) 30 for (CeedInt q=0; q<A*J*C; q++) 31 v[q] = (CeedScalar) 0.0; 32 33 for (CeedInt a=0; a<A; a++) { 34 // Blocks of 4 rows 35 for (CeedInt j=0; j<(J/JJ)*JJ; j+=JJ) { 36 for (CeedInt c=0; c<(C/CC)*CC; c+=CC) { 37 __m256d vv[JJ][CC/4]; // Output tile to be held in registers 38 for (CeedInt jj=0; jj<JJ; jj++) 39 for (CeedInt cc=0; cc<CC/4; cc++) 40 vv[jj][cc] = _mm256_loadu_pd(&v[(a*J+j+jj)*C+c+cc*4]); 41 42 for (CeedInt b=0; b<B; b++) { 43 for (CeedInt jj=0; jj<JJ; jj++) { // unroll 44 __m256d tqv = _mm256_set1_pd(t[(j+jj)*tstride0 + b*tstride1]); 45 for (CeedInt cc=0; cc<CC/4; cc++) // unroll 46 vv[jj][cc] += _mm256_mul_pd(tqv, 47 _mm256_loadu_pd(&u[(a*B+b)*C+c+cc*4])); 48 } 49 } 50 for (CeedInt jj=0; jj<JJ; jj++) 51 for (CeedInt cc=0; cc<CC/4; cc++) 52 _mm256_storeu_pd(&v[(a*J+j+jj)*C+c+cc*4], vv[jj][cc]); 53 } 54 } 55 // Remainder of rows 56 CeedInt j=(J/JJ)*JJ; 57 if (j < J) { 58 for (CeedInt c=0; c<(C/CC)*CC; c+=CC) { 59 __m256d vv[JJ][CC/4]; // Output tile to be held in registers 60 for (CeedInt jj=0; jj<J-j; jj++) 61 for (CeedInt cc=0; cc<CC/4; cc++) 62 vv[jj][cc] = _mm256_loadu_pd(&v[(a*J+j+jj)*C+c+cc*4]); 63 64 for (CeedInt b=0; b<B; b++) { 65 for (CeedInt jj=0; jj<J-j; jj++) { // doesn't unroll 66 __m256d tqv = _mm256_set1_pd(t[(j+jj)*tstride0 + b*tstride1]); 67 for (CeedInt cc=0; cc<CC/4; cc++) // unroll 68 vv[jj][cc] += _mm256_mul_pd(tqv, 69 _mm256_loadu_pd(&u[(a*B+b)*C+c+cc*4])); 70 } 71 } 72 for (CeedInt jj=0; jj<J-j; jj++) 73 for (CeedInt cc=0; cc<CC/4; cc++) 74 
_mm256_storeu_pd(&v[(a*J+j+jj)*C+c+cc*4], vv[jj][cc]); 75 } 76 } 77 } 78 return 0; 79 } 80 81 // Serial Tensor Contract Remainder 82 static inline int CeedTensorContract_Avx_Remainder(CeedTensorContract contract, 83 CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 84 CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u, 85 CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) { 86 CeedInt tstride0 = B, tstride1 = 1; 87 if (tmode == CEED_TRANSPOSE) { 88 tstride0 = 1; tstride1 = J; 89 } 90 91 CeedInt Jbreak = J%JJ ? (J/JJ)*JJ : (J/JJ-1)*JJ; 92 for (CeedInt a=0; a<A; a++) { 93 // Blocks of 4 columns 94 for (CeedInt c = (C/CC)*CC; c<C; c+=4) { 95 // Blocks of 4 rows 96 for (CeedInt j=0; j<Jbreak; j+=JJ) { 97 __m256d vv[JJ]; // Output tile to be held in registers 98 for (CeedInt jj=0; jj<JJ; jj++) 99 vv[jj] = _mm256_loadu_pd(&v[(a*J+j+jj)*C+c]); 100 101 for (CeedInt b=0; b<B; b++) { 102 __m256d tqu; 103 if (C-c == 1) 104 tqu = _mm256_set_pd(0.0, 0.0, 0.0, u[(a*B+b)*C+c+0]); 105 else if (C-c == 2) 106 tqu = _mm256_set_pd(0.0, 0.0, u[(a*B+b)*C+c+1], 107 u[(a*B+b)*C+c+0]); 108 else if (C-c == 3) 109 tqu = _mm256_set_pd(0.0, u[(a*B+b)*C+c+2], u[(a*B+b)*C+c+1], 110 u[(a*B+b)*C+c+0]); 111 else 112 tqu = _mm256_loadu_pd(&u[(a*B+b)*C+c]); 113 for (CeedInt jj=0; jj<JJ; jj++) // unroll 114 vv[jj] += _mm256_mul_pd(tqu, 115 _mm256_set1_pd(t[(j+jj)*tstride0 + b*tstride1])); 116 } 117 for (CeedInt jj=0; jj<JJ; jj++) 118 _mm256_storeu_pd(&v[(a*J+j+jj)*C+c], vv[jj]); 119 } 120 } 121 // Remainder of rows, all columns 122 for (CeedInt j=Jbreak; j<J; j++) 123 for (CeedInt b=0; b<B; b++) { 124 CeedScalar tq = t[j*tstride0 + b*tstride1]; 125 for (CeedInt c=(C/CC)*CC; c<C; c++) 126 v[(a*J+j)*C+c] += tq * u[(a*B+b)*C+c]; 127 } 128 } 129 return 0; 130 } 131 132 // Serial Tensor Contract C=1 Case 133 static inline int CeedTensorContract_Avx_Single(CeedTensorContract contract, 134 CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar 
*restrict t, 135 CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u, 136 CeedScalar *restrict v, const CeedInt AA, const CeedInt JJ) { 137 CeedInt tstride0 = B, tstride1 = 1; 138 if (tmode == CEED_TRANSPOSE) { 139 tstride0 = 1; tstride1 = J; 140 } 141 142 if (!Add) 143 for (CeedInt q=0; q<A*J*C; q++) 144 v[q] = (CeedScalar) 0.0; 145 146 // Blocks of 4 rows 147 for (CeedInt a=0; a<(A/AA)*AA; a+=AA) { 148 for (CeedInt j=0; j<(J/JJ)*JJ; j+=JJ) { 149 __m256d vv[AA][JJ/4]; // Output tile to be held in registers 150 for (CeedInt aa=0; aa<AA; aa++) 151 for (CeedInt jj=0; jj<JJ/4; jj++) 152 vv[aa][jj] = _mm256_loadu_pd(&v[(a+aa)*J+j+jj*4]); 153 154 for (CeedInt b=0; b<B; b++) { 155 for (CeedInt jj=0; jj<JJ/4; jj++) { // unroll 156 __m256d tqv = _mm256_set_pd(t[(j+jj*4+3)*tstride0 + b*tstride1], 157 t[(j+jj*4+2)*tstride0 + b*tstride1], 158 t[(j+jj*4+1)*tstride0 + b*tstride1], 159 t[(j+jj*4+0)*tstride0 + b*tstride1]); 160 for (CeedInt aa=0; aa<AA; aa++) // unroll 161 vv[aa][jj] += _mm256_mul_pd(tqv, _mm256_set1_pd(u[(a+aa)*B+b])); 162 } 163 } 164 for (CeedInt aa=0; aa<AA; aa++) 165 for (CeedInt jj=0; jj<JJ/4; jj++) 166 _mm256_storeu_pd(&v[(a+aa)*J+j+jj*4], vv[aa][jj]); 167 } 168 } 169 // Remainder of rows 170 CeedInt a=(A/AA)*AA; 171 for (CeedInt j=0; j<(J/JJ)*JJ; j+=JJ) { 172 __m256d vv[AA][JJ/4]; // Output tile to be held in registers 173 for (CeedInt aa=0; aa<A-a; aa++) 174 for (CeedInt jj=0; jj<JJ/4; jj++) 175 vv[aa][jj] = _mm256_loadu_pd(&v[(a+aa)*J+j+jj*4]); 176 177 for (CeedInt b=0; b<B; b++) { 178 for (CeedInt jj=0; jj<JJ/4; jj++) { // unroll 179 __m256d tqv = _mm256_set_pd(t[(j+jj*4+3)*tstride0 + b*tstride1], 180 t[(j+jj*4+2)*tstride0 + b*tstride1], 181 t[(j+jj*4+1)*tstride0 + b*tstride1], 182 t[(j+jj*4+0)*tstride0 + b*tstride1]); 183 for (CeedInt aa=0; aa<A-a; aa++) // unroll 184 vv[aa][jj] += _mm256_mul_pd(tqv, _mm256_set1_pd(u[(a+aa)*B+b])); 185 } 186 } 187 for (CeedInt aa=0; aa<A-a; aa++) 188 for (CeedInt jj=0; jj<JJ/4; jj++) 189 
_mm256_storeu_pd(&v[(a+aa)*J+j+jj*4], vv[aa][jj]); 190 } 191 // Column remainder 192 CeedInt Abreak = A%AA ? (A/AA)*AA : (A/AA-1)*AA; 193 // Blocks of 4 columns 194 for (CeedInt j = (J/JJ)*JJ; j<J; j+=4) { 195 // Blocks of 4 rows 196 for (CeedInt a=0; a<Abreak; a+=AA) { 197 __m256d vv[AA]; // Output tile to be held in registers 198 for (CeedInt aa=0; aa<AA; aa++) 199 vv[aa] = _mm256_loadu_pd(&v[(a+aa)*J+j]); 200 201 for (CeedInt b=0; b<B; b++) { 202 __m256d tqv; 203 if (J-j == 1) 204 tqv = _mm256_set_pd(0.0, 0.0, 0.0, t[(j+0)*tstride0 + b*tstride1]); 205 else if (J-j == 2) 206 tqv = _mm256_set_pd(0.0, 0.0, t[(j+1)*tstride0 + b*tstride1], 207 t[(j+0)*tstride0 + b*tstride1]); 208 else if (J-3 == j) 209 tqv = _mm256_set_pd(0.0, t[(j+2)*tstride0 + b*tstride1], 210 t[(j+1)*tstride0 + b*tstride1], 211 t[(j+0)*tstride0 + b*tstride1]); 212 else 213 tqv = _mm256_set_pd(t[(j+3)*tstride0 + b*tstride1], 214 t[(j+2)*tstride0 + b*tstride1], 215 t[(j+1)*tstride0 + b*tstride1], 216 t[(j+0)*tstride0 + b*tstride1]); 217 for (CeedInt aa=0; aa<AA; aa++) // unroll 218 vv[aa] += _mm256_mul_pd(tqv, _mm256_set1_pd(u[(a+aa)*B+b])); 219 } 220 for (CeedInt aa=0; aa<AA; aa++) 221 _mm256_storeu_pd(&v[(a+aa)*J+j], vv[aa]); 222 } 223 } 224 // Remainder of rows, all columns 225 for (CeedInt b=0; b<B; b++) { 226 for (CeedInt j=(J/JJ)*JJ; j<J; j++) { 227 CeedScalar tq = t[j*tstride0 + b*tstride1]; 228 for (CeedInt a=Abreak; a<A; a++) 229 v[a*J+j] += tq * u[a*B+b]; 230 } 231 } 232 return 0; 233 } 234 235 // Specific Variants 236 static int CeedTensorContract_Avx_Blocked_4_8(CeedTensorContract contract, 237 CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 238 CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u, 239 CeedScalar *restrict v) { 240 return CeedTensorContract_Avx_Blocked(contract, A, B, C, J, t, tmode, Add, u, 241 v, 4, 8); 242 } 243 static int CeedTensorContract_Avx_Remainder_8_8(CeedTensorContract contract, 244 CeedInt A, CeedInt B, CeedInt 
C, CeedInt J, const CeedScalar *restrict t, 245 CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u, 246 CeedScalar *restrict v) { 247 return CeedTensorContract_Avx_Remainder(contract, A, B, C, J, t, tmode, Add, 248 u, v, 8, 8); 249 } 250 static int CeedTensorContract_Avx_Single_4_8(CeedTensorContract contract, 251 CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 252 CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u, 253 CeedScalar *restrict v) { 254 return CeedTensorContract_Avx_Single(contract, A, B, C, J, t, tmode, Add, u, 255 v, 4, 8); 256 } 257 258 // Switch for Tensor Contract 259 static int CeedTensorContractApply_Avx(CeedTensorContract contract, CeedInt A, 260 CeedInt B, CeedInt C, CeedInt J, 261 const CeedScalar *restrict t, 262 CeedTransposeMode tmode, 263 const CeedInt Add, 264 const CeedScalar *restrict u, 265 CeedScalar *restrict v) { 266 const CeedInt blksize = 8; 267 268 if (!Add) 269 for (CeedInt q=0; q<A*J*C; q++) 270 v[q] = (CeedScalar) 0.0; 271 272 if (C == 1) { 273 // Serial C=1 Case 274 CeedTensorContract_Avx_Single_4_8(contract, A, B, C, J, t, tmode, true, u, 275 v); 276 } else { 277 // Blocks of 8 columns 278 if (C >= blksize) 279 CeedTensorContract_Avx_Blocked_4_8(contract, A, B, C, J, t, tmode, true, 280 u, v); 281 // Remainder of columns 282 if (C % blksize) 283 CeedTensorContract_Avx_Remainder_8_8(contract, A, B, C, J, t, tmode, true, 284 u, v); 285 } 286 287 return 0; 288 } 289 290 static int CeedTensorContractDestroy_Avx(CeedTensorContract contract) { 291 return 0; 292 } 293 294 int CeedTensorContractCreate_Avx(CeedBasis basis, CeedTensorContract contract) { 295 int ierr; 296 Ceed ceed; 297 ierr = CeedTensorContractGetCeed(contract, &ceed); CeedChk(ierr); 298 299 ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply", 300 CeedTensorContractApply_Avx); CeedChk(ierr); 301 ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Destroy", 302 
CeedTensorContractDestroy_Avx); CeedChk(ierr); 303 304 return 0; 305 } 306