1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC. 2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707. 3 // All Rights reserved. See files LICENSE and NOTICE for details. 4 // 5 // This file is part of CEED, a collection of benchmarks, miniapps, software 6 // libraries and APIs for efficient high-order finite element and spectral 7 // element discretizations for exascale applications. For more information and 8 // source code availability see http://github.com/ceed. 9 // 10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC, 11 // a collaborative effort of two U.S. Department of Energy organizations (Office 12 // of Science and the National Nuclear Security Administration) responsible for 13 // the planning and preparation of a capable exascale ecosystem, including 14 // software, applications, hardware, advanced system engineering and early 15 // testbed platforms, in support of the nation's exascale computing imperative. 16 17 #include <ceed/ceed.h> 18 #include <ceed/backend.h> 19 #include <ceed/hash.h> 20 #include <ceed/khash.h> 21 #include <libxsmm.h> 22 #include <stddef.h> 23 #include "ceed-xsmm.h" 24 25 //------------------------------------------------------------------------------ 26 // Tensor Contract C=1 27 //------------------------------------------------------------------------------ 28 static int CeedTensorContract_Xsmm_C1(CeedTensorContract contract, 29 CeedInt A, CeedInt B, CeedInt C, 30 CeedInt J, const CeedScalar *restrict t, 31 CeedTransposeMode t_mode, 32 const CeedInt add, 33 const CeedScalar *restrict u, 34 CeedScalar *restrict v) { 35 CeedScalar alpha = 1.0, beta = 1.0; 36 char trans_u = 'N', trans_t = 'N'; 37 if ((t_mode == CEED_TRANSPOSE && C != 1) 38 || (t_mode == CEED_NOTRANSPOSE && C == 1)) 39 trans_t = 'T'; 40 41 if (!add) 42 beta = 0.0; 43 44 // libXSMM GEMM 45 libxsmm_dgemm(&trans_t, &trans_u, &J, &A, &B, 46 &alpha, &t[0], NULL, &u[0], NULL, 47 &beta, &v[0], NULL); 48 49 return CEED_ERROR_SUCCESS; 50 } 51 52 //------------------------------------------------------------------------------ 53 // Tensor Contract Apply 54 //------------------------------------------------------------------------------ 55 static int CeedTensorContractApply_Xsmm(CeedTensorContract contract, CeedInt A, 56 CeedInt B, CeedInt C, CeedInt J, 57 const CeedScalar *restrict t, 58 CeedTransposeMode t_mode, 59 const CeedInt add, 60 const CeedScalar *restrict u, 61 CeedScalar *restrict v) { 62 int ierr; 63 CeedTensorContract_Xsmm *impl; 64 ierr = CeedTensorContractGetData(contract, &impl); CeedChkBackend(ierr); 65 66 // Get kernel 67 libxsmm_dmmfunction kernel; 68 CeedHashIJKLMKey key = {B, C, J, t_mode, add}; 69 khint_t k = kh_get(m32, impl->lookup, key); 70 CeedHashGetValue(impl->lookup, k, kernel); 71 72 // Run kernel or fallback to default implementation 73 if (C != 1) 74 for (CeedInt a=0; a<A; a++) 75 kernel(&u[a*B*C], &t[0], &v[a*J*C], NULL, NULL, NULL); 76 else 77 CeedTensorContract_Xsmm_C1(contract, A, B, C, J, t, t_mode, add, u, v); 78 79 return CEED_ERROR_SUCCESS; 80 } 81 82 //------------------------------------------------------------------------------ 83 // Tensor Contract Destroy 84 //------------------------------------------------------------------------------ 85 static int CeedTensorContractDestroy_Xsmm(CeedTensorContract contract) { 86 int ierr; 87 CeedTensorContract_Xsmm *impl; 88 libxsmm_dmmfunction kernel; 89 90 ierr = CeedTensorContractGetData(contract, &impl); CeedChkBackend(ierr); 91 // Free kernels 92 kh_foreach_value(impl->lookup, kernel, libxsmm_release_kernel(&kernel)); 93 kh_destroy(m32, impl->lookup); 94 ierr = CeedFree(&impl); CeedChkBackend(ierr); 95 return CEED_ERROR_SUCCESS; 96 } 97 98 //------------------------------------------------------------------------------ 99 // Tensor Contract Create 100 //------------------------------------------------------------------------------ 101 int CeedTensorContractCreate_Xsmm(CeedBasis basis, 102 CeedTensorContract contract) { 103 int ierr; 104 Ceed ceed; 105 ierr = CeedTensorContractGetCeed(contract, &ceed); CeedChkBackend(ierr); 106 CeedTensorContract_Xsmm *impl; 107 ierr = CeedCalloc(1, &impl); CeedChkBackend(ierr); 108 109 // Setup kernels hash table 110 impl->lookup = kh_init(m32); 111 112 // Set up pointers to kernels 113 ierr = CeedBasisIsTensor(basis, &impl->is_tensor); CeedChkBackend(ierr); 114 if (impl->is_tensor) { 115 ierr = CeedBasisGetNumNodes1D(basis, &impl->P); CeedChkBackend(ierr); 116 ierr = CeedBasisGetNumQuadraturePoints1D(basis, &impl->Q); CeedChkBackend(ierr); 117 ierr = CeedBasisGetDimension(basis, &impl->dim); CeedChkBackend(ierr); 118 // Build all required kernels 119 for (CeedInt num_elem = 1; num_elem <= 8; num_elem+=7) 120 for (CeedInt add = 0; add <= 1; add++) 121 for (CeedInt t_mode = 0; t_mode <= 1; t_mode++) 122 for (CeedInt grad = 0; grad <=1; grad++) 123 for (CeedInt dim = 0; dim < impl->dim; dim++) { 124 const int flags = LIBXSMM_GEMM_FLAGS('N', t_mode ? 'T' : 'N'); 125 CeedInt B = grad ? impl->Q : (t_mode ? impl->Q : impl->P), 126 J = grad ? impl->Q : (t_mode ? impl->P : impl->Q), 127 C = num_elem*CeedIntPow(J, dim); 128 // Add key, kernel pair to hash table 129 CeedHashIJKLMKey key = {B, C, J, t_mode, add}; 130 int new_item; 131 khint_t k = kh_put(m32, impl->lookup, key, &new_item); 132 if (new_item) { 133 // Build kernel 134 CeedScalar alpha = 1.0, beta = 1.0; 135 if (!add) beta = 0.0; 136 libxsmm_dmmfunction kernel = libxsmm_dmmdispatch( 137 C, J, B, NULL, NULL, NULL, &alpha, &beta, &flags, NULL); 138 if (!kernel) 139 // LCOV_EXCL_START 140 return CeedError(ceed, CEED_ERROR_BACKEND, "LIBXSMM kernel failed to build."); 141 // LCOV_EXCL_STOP 142 // Add kernel to hash table 143 kh_value(impl->lookup, k) = kernel; 144 } 145 } 146 } else { 147 ierr = CeedBasisGetNumNodes(basis, &impl->P); CeedChkBackend(ierr); 148 ierr = CeedBasisGetNumQuadraturePoints(basis, &impl->Q); CeedChkBackend(ierr); 149 ierr = CeedBasisGetDimension(basis, &impl->dim); CeedChkBackend(ierr); 150 // Build all required kernels 151 for (CeedInt num_elem = 1; num_elem <= 8; num_elem+=7) 152 for (CeedInt add = 0; add <= 1; add++) 153 for (CeedInt t_mode = 0; t_mode <= 1; t_mode++) { 154 CeedInt gradstride = CeedIntMax(impl->dim-1, 1); 155 for (CeedInt grad = 1; grad <= impl->dim; grad+=gradstride) { 156 const int flags = LIBXSMM_GEMM_FLAGS('N', t_mode ? 'T' : 'N'); 157 CeedInt B = t_mode ? grad*impl->Q : impl->P, 158 J = t_mode ? impl->P : grad*impl->Q, 159 C = num_elem; 160 // Add key, kernel pair to hash table 161 CeedHashIJKLMKey key = {B, C, J, t_mode, add}; 162 int new_item; 163 khint_t k = kh_put(m32, impl->lookup, key, &new_item); 164 if (new_item) { 165 // Build kernel 166 CeedScalar alpha = 1.0, beta = 1.0; 167 if (!add) beta = 0.0; 168 libxsmm_dmmfunction kernel = libxsmm_dmmdispatch( 169 C, J, B, NULL, NULL, NULL, &alpha, &beta, &flags, NULL); 170 if (!kernel) 171 // LCOV_EXCL_START 172 return CeedError(ceed, CEED_ERROR_BACKEND, "LIBXSMM kernel failed to build."); 173 // LCOV_EXCL_STOP 174 // Add kernel to hash table 175 kh_value(impl->lookup, k) = kernel; 176 } 177 } 178 } 179 } 180 ierr = CeedTensorContractSetData(contract, impl); CeedChkBackend(ierr); 181 182 ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply", 183 CeedTensorContractApply_Xsmm); CeedChkBackend(ierr); 184 ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Destroy", 185 CeedTensorContractDestroy_Xsmm); CeedChkBackend(ierr); 186 187 return CEED_ERROR_SUCCESS; 188 } 189 //------------------------------------------------------------------------------ 190