// Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
// Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
// All Rights reserved. See files LICENSE and NOTICE for details.
//
// This file is part of CEED, a collection of benchmarks, miniapps, software
// libraries and APIs for efficient high-order finite element and spectral
// element discretizations for exascale applications. For more information and
// source code availability see http://github.com/ceed.
//
// The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
// a collaborative effort of two U.S. Department of Energy organizations (Office
// of Science and the National Nuclear Security Administration) responsible for
// the planning and preparation of a capable exascale ecosystem, including
// software, applications, hardware, advanced system engineering and early
// testbed platforms, in support of the nation's exascale computing imperative.
16 17 #include "ceed-avx.h" 18 19 // Blocked Tensor Contact 20 static inline int CeedTensorContract_Avx_Blocked(CeedTensorContract contract, 21 CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 22 CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u, 23 CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) { 24 CeedInt tstride0 = B, tstride1 = 1; 25 if (tmode == CEED_TRANSPOSE) { 26 tstride0 = 1; tstride1 = J; 27 } 28 29 for (CeedInt a=0; a<A; a++) { 30 // Blocks of 4 rows 31 for (CeedInt j=0; j<(J/JJ)*JJ; j+=JJ) { 32 for (CeedInt c=0; c<(C/CC)*CC; c+=CC) { 33 __m256d vv[JJ][CC/4]; // Output tile to be held in registers 34 for (CeedInt jj=0; jj<JJ; jj++) 35 for (CeedInt cc=0; cc<CC/4; cc++) 36 vv[jj][cc] = _mm256_loadu_pd(&v[(a*J+j+jj)*C+c+cc*4]); 37 38 for (CeedInt b=0; b<B; b++) { 39 for (CeedInt jj=0; jj<JJ; jj++) { // unroll 40 __m256d tqv = _mm256_set1_pd(t[(j+jj)*tstride0 + b*tstride1]); 41 for (CeedInt cc=0; cc<CC/4; cc++) // unroll 42 vv[jj][cc] += _mm256_mul_pd(tqv, 43 _mm256_loadu_pd(&u[(a*B+b)*C+c+cc*4])); 44 } 45 } 46 for (CeedInt jj=0; jj<JJ; jj++) 47 for (CeedInt cc=0; cc<CC/4; cc++) 48 _mm256_storeu_pd(&v[(a*J+j+jj)*C+c+cc*4], vv[jj][cc]); 49 } 50 } 51 // Remainder of rows 52 CeedInt j=(J/JJ)*JJ; 53 if (j < J) { 54 for (CeedInt c=0; c<(C/CC)*CC; c+=CC) { 55 __m256d vv[JJ][CC/4]; // Output tile to be held in registers 56 for (CeedInt jj=0; jj<J-j; jj++) 57 for (CeedInt cc=0; cc<CC/4; cc++) 58 vv[jj][cc] = _mm256_loadu_pd(&v[(a*J+j+jj)*C+c+cc*4]); 59 60 for (CeedInt b=0; b<B; b++) { 61 for (CeedInt jj=0; jj<J-j; jj++) { // doesn't unroll 62 __m256d tqv = _mm256_set1_pd(t[(j+jj)*tstride0 + b*tstride1]); 63 for (CeedInt cc=0; cc<CC/4; cc++) // unroll 64 vv[jj][cc] += _mm256_mul_pd(tqv, 65 _mm256_loadu_pd(&u[(a*B+b)*C+c+cc*4])); 66 } 67 } 68 for (CeedInt jj=0; jj<J-j; jj++) 69 for (CeedInt cc=0; cc<CC/4; cc++) 70 _mm256_storeu_pd(&v[(a*J+j+jj)*C+c+cc*4], vv[jj][cc]); 71 } 72 } 73 } 74 return 0; 75 } 76 77 
// Serial Tensor Contract Remainder 78 static inline int CeedTensorContract_Avx_Remainder(CeedTensorContract contract, 79 CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 80 CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u, 81 CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) { 82 CeedInt tstride0 = B, tstride1 = 1; 83 if (tmode == CEED_TRANSPOSE) { 84 tstride0 = 1; tstride1 = J; 85 } 86 87 CeedInt Jbreak = J%JJ ? (J/JJ)*JJ : (J/JJ-1)*JJ; 88 for (CeedInt a=0; a<A; a++) { 89 // Blocks of 4 columns 90 for (CeedInt c = (C/CC)*CC; c<C; c+=4) { 91 // Blocks of 4 rows 92 for (CeedInt j=0; j<Jbreak; j+=JJ) { 93 __m256d vv[JJ]; // Output tile to be held in registers 94 for (CeedInt jj=0; jj<JJ; jj++) 95 vv[jj] = _mm256_loadu_pd(&v[(a*J+j+jj)*C+c]); 96 97 for (CeedInt b=0; b<B; b++) { 98 __m256d tqu; 99 if (C-c == 1) 100 tqu = _mm256_set_pd(0.0, 0.0, 0.0, u[(a*B+b)*C+c+0]); 101 else if (C-c == 2) 102 tqu = _mm256_set_pd(0.0, 0.0, u[(a*B+b)*C+c+1], 103 u[(a*B+b)*C+c+0]); 104 else if (C-c == 3) 105 tqu = _mm256_set_pd(0.0, u[(a*B+b)*C+c+2], u[(a*B+b)*C+c+1], 106 u[(a*B+b)*C+c+0]); 107 else 108 tqu = _mm256_loadu_pd(&u[(a*B+b)*C+c]); 109 for (CeedInt jj=0; jj<JJ; jj++) // unroll 110 vv[jj] += _mm256_mul_pd(tqu, 111 _mm256_set1_pd(t[(j+jj)*tstride0 + b*tstride1])); 112 } 113 for (CeedInt jj=0; jj<JJ; jj++) 114 _mm256_storeu_pd(&v[(a*J+j+jj)*C+c], vv[jj]); 115 } 116 } 117 // Remainder of rows, all columns 118 for (CeedInt j=Jbreak; j<J; j++) 119 for (CeedInt b=0; b<B; b++) { 120 CeedScalar tq = t[j*tstride0 + b*tstride1]; 121 for (CeedInt c=(C/CC)*CC; c<C; c++) 122 v[(a*J+j)*C+c] += tq * u[(a*B+b)*C+c]; 123 } 124 } 125 return 0; 126 } 127 128 // Serial Tensor Contract C=1 Case 129 static inline int CeedTensorContract_Avx_Single(CeedTensorContract contract, 130 CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 131 CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u, 132 
CeedScalar *restrict v, const CeedInt AA, const CeedInt JJ) { 133 CeedInt tstride0 = B, tstride1 = 1; 134 if (tmode == CEED_TRANSPOSE) { 135 tstride0 = 1; tstride1 = J; 136 } 137 138 // Blocks of 4 rows 139 for (CeedInt a=0; a<(A/AA)*AA; a+=AA) { 140 for (CeedInt j=0; j<(J/JJ)*JJ; j+=JJ) { 141 __m256d vv[AA][JJ/4]; // Output tile to be held in registers 142 for (CeedInt aa=0; aa<AA; aa++) 143 for (CeedInt jj=0; jj<JJ/4; jj++) 144 vv[aa][jj] = _mm256_loadu_pd(&v[(a+aa)*J+j+jj*4]); 145 146 for (CeedInt b=0; b<B; b++) { 147 for (CeedInt jj=0; jj<JJ/4; jj++) { // unroll 148 __m256d tqv = _mm256_set_pd(t[(j+jj*4+3)*tstride0 + b*tstride1], 149 t[(j+jj*4+2)*tstride0 + b*tstride1], 150 t[(j+jj*4+1)*tstride0 + b*tstride1], 151 t[(j+jj*4+0)*tstride0 + b*tstride1]); 152 for (CeedInt aa=0; aa<AA; aa++) // unroll 153 vv[aa][jj] += _mm256_mul_pd(tqv, _mm256_set1_pd(u[(a+aa)*B+b])); 154 } 155 } 156 for (CeedInt aa=0; aa<AA; aa++) 157 for (CeedInt jj=0; jj<JJ/4; jj++) 158 _mm256_storeu_pd(&v[(a+aa)*J+j+jj*4], vv[aa][jj]); 159 } 160 } 161 // Remainder of rows 162 CeedInt a=(A/AA)*AA; 163 for (CeedInt j=0; j<(J/JJ)*JJ; j+=JJ) { 164 __m256d vv[AA][JJ/4]; // Output tile to be held in registers 165 for (CeedInt aa=0; aa<A-a; aa++) 166 for (CeedInt jj=0; jj<JJ/4; jj++) 167 vv[aa][jj] = _mm256_loadu_pd(&v[(a+aa)*J+j+jj*4]); 168 169 for (CeedInt b=0; b<B; b++) { 170 for (CeedInt jj=0; jj<JJ/4; jj++) { // unroll 171 __m256d tqv = _mm256_set_pd(t[(j+jj*4+3)*tstride0 + b*tstride1], 172 t[(j+jj*4+2)*tstride0 + b*tstride1], 173 t[(j+jj*4+1)*tstride0 + b*tstride1], 174 t[(j+jj*4+0)*tstride0 + b*tstride1]); 175 for (CeedInt aa=0; aa<A-a; aa++) // unroll 176 vv[aa][jj] += _mm256_mul_pd(tqv, _mm256_set1_pd(u[(a+aa)*B+b])); 177 } 178 } 179 for (CeedInt aa=0; aa<A-a; aa++) 180 for (CeedInt jj=0; jj<JJ/4; jj++) 181 _mm256_storeu_pd(&v[(a+aa)*J+j+jj*4], vv[aa][jj]); 182 } 183 // Column remainder 184 CeedInt Abreak = A%AA ? 
(A/AA)*AA : (A/AA-1)*AA; 185 // Blocks of 4 columns 186 for (CeedInt j = (J/JJ)*JJ; j<J; j+=4) { 187 // Blocks of 4 rows 188 for (CeedInt a=0; a<Abreak; a+=AA) { 189 __m256d vv[AA]; // Output tile to be held in registers 190 for (CeedInt aa=0; aa<AA; aa++) 191 vv[aa] = _mm256_loadu_pd(&v[(a+aa)*J+j]); 192 193 for (CeedInt b=0; b<B; b++) { 194 __m256d tqv; 195 if (J-j == 1) 196 tqv = _mm256_set_pd(0.0, 0.0, 0.0, t[(j+0)*tstride0 + b*tstride1]); 197 else if (J-j == 2) 198 tqv = _mm256_set_pd(0.0, 0.0, t[(j+1)*tstride0 + b*tstride1], 199 t[(j+0)*tstride0 + b*tstride1]); 200 else if (J-3 == j) 201 tqv = _mm256_set_pd(0.0, t[(j+2)*tstride0 + b*tstride1], 202 t[(j+1)*tstride0 + b*tstride1], 203 t[(j+0)*tstride0 + b*tstride1]); 204 else 205 tqv = _mm256_set_pd(t[(j+3)*tstride0 + b*tstride1], 206 t[(j+2)*tstride0 + b*tstride1], 207 t[(j+1)*tstride0 + b*tstride1], 208 t[(j+0)*tstride0 + b*tstride1]); 209 for (CeedInt aa=0; aa<AA; aa++) // unroll 210 vv[aa] += _mm256_mul_pd(tqv, _mm256_set1_pd(u[(a+aa)*B+b])); 211 } 212 for (CeedInt aa=0; aa<AA; aa++) 213 _mm256_storeu_pd(&v[(a+aa)*J+j], vv[aa]); 214 } 215 } 216 // Remainder of rows, all columns 217 for (CeedInt b=0; b<B; b++) { 218 for (CeedInt j=(J/JJ)*JJ; j<J; j++) { 219 CeedScalar tq = t[j*tstride0 + b*tstride1]; 220 for (CeedInt a=Abreak; a<A; a++) 221 v[a*J+j] += tq * u[a*B+b]; 222 } 223 } 224 return 0; 225 } 226 227 // Specific Variants 228 static int CeedTensorContract_Avx_Blocked_4_8(CeedTensorContract contract, 229 CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 230 CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u, 231 CeedScalar *restrict v) { 232 return CeedTensorContract_Avx_Blocked(contract, A, B, C, J, t, tmode, Add, u, 233 v, 4, 8); 234 } 235 static int CeedTensorContract_Avx_Remainder_8_8(CeedTensorContract contract, 236 CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 237 CeedTransposeMode tmode, const CeedInt Add, const CeedScalar 
*restrict u, 238 CeedScalar *restrict v) { 239 return CeedTensorContract_Avx_Remainder(contract, A, B, C, J, t, tmode, Add, 240 u, v, 8, 8); 241 } 242 static int CeedTensorContract_Avx_Single_4_8(CeedTensorContract contract, 243 CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 244 CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u, 245 CeedScalar *restrict v) { 246 return CeedTensorContract_Avx_Single(contract, A, B, C, J, t, tmode, Add, u, 247 v, 4, 8); 248 } 249 250 // Switch for Tensor Contract 251 static int CeedTensorContractApply_Avx(CeedTensorContract contract, CeedInt A, 252 CeedInt B, CeedInt C, CeedInt J, 253 const CeedScalar *restrict t, 254 CeedTransposeMode tmode, 255 const CeedInt Add, 256 const CeedScalar *restrict u, 257 CeedScalar *restrict v) { 258 const CeedInt blksize = 8; 259 260 if (!Add) 261 for (CeedInt q=0; q<A*J*C; q++) 262 v[q] = (CeedScalar) 0.0; 263 264 if (C == 1) { 265 // Serial C=1 Case 266 CeedTensorContract_Avx_Single_4_8(contract, A, B, C, J, t, tmode, true, u, 267 v); 268 } else { 269 // Blocks of 8 columns 270 if (C >= blksize) 271 CeedTensorContract_Avx_Blocked_4_8(contract, A, B, C, J, t, tmode, true, 272 u, v); 273 // Remainder of columns 274 if (C % blksize) 275 CeedTensorContract_Avx_Remainder_8_8(contract, A, B, C, J, t, tmode, true, 276 u, v); 277 } 278 279 return 0; 280 } 281 282 static int CeedTensorContractDestroy_Avx(CeedTensorContract contract) { 283 return 0; 284 } 285 286 int CeedTensorContractCreate_Avx(CeedBasis basis, CeedTensorContract contract) { 287 int ierr; 288 Ceed ceed; 289 ierr = CeedTensorContractGetCeed(contract, &ceed); CeedChk(ierr); 290 291 ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply", 292 CeedTensorContractApply_Avx); CeedChk(ierr); 293 ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Destroy", 294 CeedTensorContractDestroy_Avx); CeedChk(ierr); 295 296 return 0; 297 } 298