// Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
// Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
// All Rights reserved. See files LICENSE and NOTICE for details.
//
// This file is part of CEED, a collection of benchmarks, miniapps, software
// libraries and APIs for efficient high-order finite element and spectral
// element discretizations for exascale applications. For more information and
// source code availability see http://github.com/ceed.
//
// The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
// a collaborative effort of two U.S. Department of Energy organizations (Office
// of Science and the National Nuclear Security Administration) responsible for
// the planning and preparation of a capable exascale ecosystem, including
// software, applications, hardware, advanced system engineering and early
// testbed platforms, in support of the nation's exascale computing imperative.
16 17 #include "ceed-xsmm.h" 18 19 //------------------------------------------------------------------------------ 20 // Tensor Contract C=1 21 //------------------------------------------------------------------------------ 22 static int CeedTensorContract_Xsmm_C1(CeedTensorContract contract, 23 CeedInt A, CeedInt B, CeedInt C, 24 CeedInt J, const CeedScalar *restrict t, 25 CeedTransposeMode tmode, 26 const CeedInt add, 27 const CeedScalar *restrict u, 28 CeedScalar *restrict v) { 29 CeedScalar alpha = 1.0, beta = 1.0; 30 char transu = 'N', transt = 'N'; 31 if ((tmode == CEED_TRANSPOSE && C != 1) 32 || (tmode == CEED_NOTRANSPOSE && C == 1)) 33 transt = 'T'; 34 35 if (!add) 36 beta = 0.0; 37 38 // libXSMM GEMM 39 libxsmm_dgemm(&transt, &transu, &J, &A, &B, 40 &alpha, &t[0], NULL, &u[0], NULL, 41 &beta, &v[0], NULL); 42 43 return 0; 44 } 45 46 //------------------------------------------------------------------------------ 47 // Tensor Contract Apply 48 //------------------------------------------------------------------------------ 49 static int CeedTensorContractApply_Xsmm(CeedTensorContract contract, CeedInt A, 50 CeedInt B, CeedInt C, CeedInt J, 51 const CeedScalar *restrict t, 52 CeedTransposeMode tmode, 53 const CeedInt add, 54 const CeedScalar *restrict u, 55 CeedScalar *restrict v) { 56 int ierr; 57 CeedTensorContract_Xsmm *impl; 58 ierr = CeedTensorContractGetData(contract, (void *)&impl); CeedChk(ierr); 59 60 // Get kernel 61 libxsmm_dmmfunction kernel; 62 CeedHashIJKLMKey key = {B, C, J, tmode, add}; 63 khint_t k = kh_get(m32, impl->lookup, key); 64 CeedHashGetValue(impl->lookup, k, kernel); 65 66 // Run kernel or fallback to default implementation 67 if (C != 1) 68 for (CeedInt a=0; a<A; a++) 69 kernel(&u[a*B*C], &t[0], &v[a*J*C], NULL, NULL, NULL); 70 else 71 CeedTensorContract_Xsmm_C1(contract, A, B, C, J, t, tmode, add, u, v); 72 73 return 0; 74 } 75 76 //------------------------------------------------------------------------------ 77 // Tensor 
Contract Destroy 78 //------------------------------------------------------------------------------ 79 static int CeedTensorContractDestroy_Xsmm(CeedTensorContract contract) { 80 int ierr; 81 CeedTensorContract_Xsmm *impl; 82 libxsmm_dmmfunction kernel; 83 84 ierr = CeedTensorContractGetData(contract, (void *)&impl); CeedChk(ierr); 85 // Free kernels 86 kh_foreach_value(impl->lookup, kernel, libxsmm_release_kernel(&kernel)); 87 kh_destroy(m32, impl->lookup); 88 ierr = CeedFree(&impl); CeedChk(ierr); 89 return 0; 90 } 91 92 //------------------------------------------------------------------------------ 93 // Tensor Contract Create 94 //------------------------------------------------------------------------------ 95 int CeedTensorContractCreate_Xsmm(CeedBasis basis, 96 CeedTensorContract contract) { 97 int ierr; 98 Ceed ceed; 99 ierr = CeedTensorContractGetCeed(contract, &ceed); CeedChk(ierr); 100 CeedTensorContract_Xsmm *impl; 101 ierr = CeedCalloc(1, &impl); CeedChk(ierr); 102 103 // Setup kernels hash table 104 impl->lookup = kh_init(m32); 105 106 // Set up pointers to kernels 107 ierr = CeedBasisIsTensor(basis, &impl->isTensor); CeedChk(ierr); 108 if (impl->isTensor) { 109 ierr = CeedBasisGetNumNodes1D(basis, &impl->P); CeedChk(ierr); 110 ierr = CeedBasisGetNumQuadraturePoints1D(basis, &impl->Q); CeedChk(ierr); 111 ierr = CeedBasisGetDimension(basis, &impl->dim); CeedChk(ierr); 112 // Build all required kernels 113 for (CeedInt nelem = 1; nelem <= 8; nelem+=7) 114 for (CeedInt add = 0; add <= 1; add++) 115 for (CeedInt tmode = 0; tmode <= 1; tmode++) 116 for (CeedInt grad = 0; grad <=1; grad++) 117 for (CeedInt dim = 0; dim < impl->dim; dim++) { 118 const int flags = LIBXSMM_GEMM_FLAGS('N', tmode ? 'T' : 'N'); 119 CeedInt B = grad ? impl->Q : (tmode ? impl->Q : impl->P), 120 J = grad ? impl->Q : (tmode ? 
impl->P : impl->Q), 121 C = nelem*CeedIntPow(J, dim); 122 // Add key, kernel pair to hash table 123 CeedHashIJKLMKey key = {B, C, J, tmode, add}; 124 int new_item; 125 khint_t k = kh_put(m32, impl->lookup, key, &new_item); 126 if (new_item) { 127 // Build kernel 128 CeedScalar alpha = 1.0, beta = 1.0; 129 if (!add) beta = 0.0; 130 libxsmm_dmmfunction kernel = libxsmm_dmmdispatch( 131 C, J, B, NULL, NULL, NULL, &alpha, &beta, &flags, NULL); 132 if (!kernel) 133 // LCOV_EXCL_START 134 return CeedError(ceed, 1, "LIBXSMM kernel failed to build."); 135 // LCOV_EXCL_STOP 136 // Add kernel to hash table 137 kh_value(impl->lookup, k) = kernel; 138 } 139 } 140 } else { 141 ierr = CeedBasisGetNumNodes(basis, &impl->P); CeedChk(ierr); 142 ierr = CeedBasisGetNumQuadraturePoints(basis, &impl->Q); CeedChk(ierr); 143 ierr = CeedBasisGetDimension(basis, &impl->dim); CeedChk(ierr); 144 // Build all required kernels 145 for (CeedInt nelem = 1; nelem <= 8; nelem+=7) 146 for (CeedInt add = 0; add <= 1; add++) 147 for (CeedInt tmode = 0; tmode <= 1; tmode++) { 148 CeedInt gradstride = CeedIntMax(impl->dim-1, 1); 149 for (CeedInt grad = 1; grad <= impl->dim; grad+=gradstride) { 150 const int flags = LIBXSMM_GEMM_FLAGS('N', tmode ? 'T' : 'N'); 151 CeedInt B = tmode ? grad*impl->Q : impl->P, 152 J = tmode ? 
impl->P : grad*impl->Q, 153 C = nelem; 154 // Add key, kernel pair to hash table 155 CeedHashIJKLMKey key = {B, C, J, tmode, add}; 156 int new_item; 157 khint_t k = kh_put(m32, impl->lookup, key, &new_item); 158 if (new_item) { 159 // Build kernel 160 CeedScalar alpha = 1.0, beta = 1.0; 161 if (!add) beta = 0.0; 162 libxsmm_dmmfunction kernel = libxsmm_dmmdispatch( 163 C, J, B, NULL, NULL, NULL, &alpha, &beta, &flags, NULL); 164 if (!kernel) 165 // LCOV_EXCL_START 166 return CeedError(ceed, 1, "LIBXSMM kernel failed to build."); 167 // LCOV_EXCL_STOP 168 // Add kernel to hash table 169 kh_value(impl->lookup, k) = kernel; 170 } 171 } 172 } 173 } 174 ierr = CeedTensorContractSetData(contract, (void *)&impl); CeedChk(ierr); 175 176 ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply", 177 CeedTensorContractApply_Xsmm); CeedChk(ierr); 178 ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Destroy", 179 CeedTensorContractDestroy_Xsmm); CeedChk(ierr); 180 181 return 0; 182 } 183 //------------------------------------------------------------------------------ 184