1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 #include <ceed.h> 9 #include <ceed/backend.h> 10 #include <libxsmm.h> 11 12 #include "ceed-xsmm.h" 13 14 //------------------------------------------------------------------------------ 15 // Tensor Contract Apply 16 //------------------------------------------------------------------------------ 17 static int CeedTensorContractApply_Xsmm(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 18 CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 19 Ceed ceed; 20 CeedCallBackend(CeedTensorContractGetCeed(contract, &ceed)); 21 22 if (C == 1) { 23 // Build or query the required kernel 24 const int flags_t = LIBXSMM_GEMM_FLAGS(!t_mode ? 'T' : 'N', 'N'); 25 const int flags_ab = (!add) ? LIBXSMM_GEMM_FLAG_BETA_0 : LIBXSMM_BASIC_GEMM_FLAG_NONE; 26 const int flags = (flags_t | flags_ab); 27 const libxsmm_gemm_shape gemm_shape = (CEED_SCALAR_TYPE == CEED_SCALAR_FP64) 28 ? libxsmm_create_gemm_shape(J, A, B, !t_mode ? B : J, B, J, LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64, 29 LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64) 30 : libxsmm_create_gemm_shape(J, A, B, !t_mode ? B : J, B, J, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, 31 LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32); 32 const libxsmm_gemmfunction kernel = libxsmm_dispatch_gemm_v2(gemm_shape, (libxsmm_bitfield)(flags), (libxsmm_bitfield)LIBXSMM_GEMM_PREFETCH_NONE); 33 CeedCheck(kernel, ceed, CEED_ERROR_BACKEND, "LIBXSMM kernel failed to build."); 34 35 // Run kernel 36 libxsmm_gemm_param gemm_param; 37 gemm_param.a.primary = (CeedScalar *)&t[0]; 38 gemm_param.b.primary = (CeedScalar *)&u[0]; 39 gemm_param.c.primary = (CeedScalar *)&v[0]; 40 kernel(&gemm_param); 41 } else { 42 // Build or query the required kernel 43 const int flags_t = LIBXSMM_GEMM_FLAGS('N', t_mode ? 'T' : 'N'); 44 const int flags_ab = (!add) ? LIBXSMM_GEMM_FLAG_BETA_0 : LIBXSMM_BASIC_GEMM_FLAG_NONE; 45 const int flags = (flags_t | flags_ab); 46 const libxsmm_gemm_shape gemm_shape = (CEED_SCALAR_TYPE == CEED_SCALAR_FP64) 47 ? libxsmm_create_gemm_shape(C, J, B, C, !t_mode ? B : J, C, LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64, 48 LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64) 49 : libxsmm_create_gemm_shape(C, J, B, C, !t_mode ? B : J, C, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, 50 LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32); 51 const libxsmm_gemmfunction kernel = libxsmm_dispatch_gemm_v2(gemm_shape, (libxsmm_bitfield)(flags), (libxsmm_bitfield)LIBXSMM_GEMM_PREFETCH_NONE); 52 CeedCheck(kernel, ceed, CEED_ERROR_BACKEND, "LIBXSMM kernel failed to build."); 53 54 // Run kernel 55 libxsmm_gemm_param gemm_param; 56 gemm_param.b.primary = (CeedScalar *)&t[0]; 57 for (CeedInt a = 0; a < A; a++) { 58 gemm_param.a.primary = (CeedScalar *)&u[a * B * C]; 59 gemm_param.c.primary = (CeedScalar *)&v[a * J * C]; 60 kernel(&gemm_param); 61 } 62 } 63 64 return CEED_ERROR_SUCCESS; 65 } 66 67 //------------------------------------------------------------------------------ 68 // Tensor Contract Create 69 //------------------------------------------------------------------------------ 70 int CeedTensorContractCreate_Xsmm(CeedBasis basis, CeedTensorContract contract) { 71 Ceed ceed; 72 CeedCallBackend(CeedTensorContractGetCeed(contract, &ceed)); 73 74 CeedCallBackend(CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply", CeedTensorContractApply_Xsmm)); 75 76 return CEED_ERROR_SUCCESS; 77 } 78 79 //------------------------------------------------------------------------------ 80