1*4548da4eSSebastian Grimberg // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. 2*4548da4eSSebastian Grimberg // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3*4548da4eSSebastian Grimberg // 4*4548da4eSSebastian Grimberg // SPDX-License-Identifier: BSD-2-Clause 5*4548da4eSSebastian Grimberg // 6*4548da4eSSebastian Grimberg // This file is part of CEED: http://github.com/ceed 7*4548da4eSSebastian Grimberg 8*4548da4eSSebastian Grimberg #include <ceed.h> 9*4548da4eSSebastian Grimberg #include <ceed/backend.h> 10*4548da4eSSebastian Grimberg #include <libxsmm.h> 11*4548da4eSSebastian Grimberg 12*4548da4eSSebastian Grimberg #include "ceed-xsmm.h" 13*4548da4eSSebastian Grimberg 14*4548da4eSSebastian Grimberg //------------------------------------------------------------------------------ 15*4548da4eSSebastian Grimberg // Tensor Contract Apply 16*4548da4eSSebastian Grimberg //------------------------------------------------------------------------------ 17*4548da4eSSebastian Grimberg static int CeedTensorContractApply_Xsmm(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t, 18*4548da4eSSebastian Grimberg CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) { 19*4548da4eSSebastian Grimberg Ceed ceed; 20*4548da4eSSebastian Grimberg CeedCallBackend(CeedTensorContractGetCeed(contract, &ceed)); 21*4548da4eSSebastian Grimberg 22*4548da4eSSebastian Grimberg if (C == 1) { 23*4548da4eSSebastian Grimberg // Build or query the required kernel 24*4548da4eSSebastian Grimberg const int flags_t = LIBXSMM_GEMM_FLAGS(!t_mode ? 'T' : 'N', 'N'); 25*4548da4eSSebastian Grimberg const int flags_ab = (!add) ? LIBXSMM_GEMM_FLAG_BETA_0 : LIBXSMM_BASIC_GEMM_FLAG_NONE; 26*4548da4eSSebastian Grimberg const int flags = (flags_t | flags_ab); 27*4548da4eSSebastian Grimberg const libxsmm_gemm_shape gemm_shape = (CEED_SCALAR_TYPE == CEED_SCALAR_FP64) 28*4548da4eSSebastian Grimberg ? libxsmm_create_gemm_shape(J, A, B, !t_mode ? B : J, B, J, LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64, 29*4548da4eSSebastian Grimberg LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64) 30*4548da4eSSebastian Grimberg : libxsmm_create_gemm_shape(J, A, B, !t_mode ? B : J, B, J, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, 31*4548da4eSSebastian Grimberg LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32); 32*4548da4eSSebastian Grimberg const libxsmm_gemmfunction kernel = libxsmm_dispatch_gemm_v2(gemm_shape, (libxsmm_bitfield)(flags), (libxsmm_bitfield)LIBXSMM_GEMM_PREFETCH_NONE); 33*4548da4eSSebastian Grimberg CeedCheck(kernel, ceed, CEED_ERROR_BACKEND, "LIBXSMM kernel failed to build."); 34*4548da4eSSebastian Grimberg 35*4548da4eSSebastian Grimberg // Run kernel 36*4548da4eSSebastian Grimberg libxsmm_gemm_param gemm_param; 37*4548da4eSSebastian Grimberg gemm_param.a.primary = (CeedScalar *)&t[0]; 38*4548da4eSSebastian Grimberg gemm_param.b.primary = (CeedScalar *)&u[0]; 39*4548da4eSSebastian Grimberg gemm_param.c.primary = (CeedScalar *)&v[0]; 40*4548da4eSSebastian Grimberg kernel(&gemm_param); 41*4548da4eSSebastian Grimberg } else { 42*4548da4eSSebastian Grimberg // Build or query the required kernel 43*4548da4eSSebastian Grimberg const int flags_t = LIBXSMM_GEMM_FLAGS('N', t_mode ? 'T' : 'N'); 44*4548da4eSSebastian Grimberg const int flags_ab = (!add) ? LIBXSMM_GEMM_FLAG_BETA_0 : LIBXSMM_BASIC_GEMM_FLAG_NONE; 45*4548da4eSSebastian Grimberg const int flags = (flags_t | flags_ab); 46*4548da4eSSebastian Grimberg const libxsmm_gemm_shape gemm_shape = (CEED_SCALAR_TYPE == CEED_SCALAR_FP64) 47*4548da4eSSebastian Grimberg ? libxsmm_create_gemm_shape(C, J, B, C, !t_mode ? B : J, C, LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64, 48*4548da4eSSebastian Grimberg LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64) 49*4548da4eSSebastian Grimberg : libxsmm_create_gemm_shape(C, J, B, C, !t_mode ? B : J, C, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32, 50*4548da4eSSebastian Grimberg LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32); 51*4548da4eSSebastian Grimberg const libxsmm_gemmfunction kernel = libxsmm_dispatch_gemm_v2(gemm_shape, (libxsmm_bitfield)(flags), (libxsmm_bitfield)LIBXSMM_GEMM_PREFETCH_NONE); 52*4548da4eSSebastian Grimberg CeedCheck(kernel, ceed, CEED_ERROR_BACKEND, "LIBXSMM kernel failed to build."); 53*4548da4eSSebastian Grimberg 54*4548da4eSSebastian Grimberg // Run kernel 55*4548da4eSSebastian Grimberg libxsmm_gemm_param gemm_param; 56*4548da4eSSebastian Grimberg gemm_param.b.primary = (CeedScalar *)&t[0]; 57*4548da4eSSebastian Grimberg for (CeedInt a = 0; a < A; a++) { 58*4548da4eSSebastian Grimberg gemm_param.a.primary = (CeedScalar *)&u[a * B * C]; 59*4548da4eSSebastian Grimberg gemm_param.c.primary = (CeedScalar *)&v[a * J * C]; 60*4548da4eSSebastian Grimberg kernel(&gemm_param); 61*4548da4eSSebastian Grimberg } 62*4548da4eSSebastian Grimberg } 63*4548da4eSSebastian Grimberg 64*4548da4eSSebastian Grimberg return CEED_ERROR_SUCCESS; 65*4548da4eSSebastian Grimberg } 66*4548da4eSSebastian Grimberg 67*4548da4eSSebastian Grimberg //------------------------------------------------------------------------------ 68*4548da4eSSebastian Grimberg // Tensor Contract Create 69*4548da4eSSebastian Grimberg //------------------------------------------------------------------------------ 70*4548da4eSSebastian Grimberg int CeedTensorContractCreate_Xsmm(CeedBasis basis, CeedTensorContract contract) { 71*4548da4eSSebastian Grimberg Ceed ceed; 72*4548da4eSSebastian Grimberg CeedCallBackend(CeedTensorContractGetCeed(contract, &ceed)); 73*4548da4eSSebastian Grimberg 74*4548da4eSSebastian Grimberg CeedCallBackend(CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply", CeedTensorContractApply_Xsmm)); 75*4548da4eSSebastian Grimberg 76*4548da4eSSebastian Grimberg return CEED_ERROR_SUCCESS; 77*4548da4eSSebastian Grimberg } 78*4548da4eSSebastian Grimberg 79*4548da4eSSebastian Grimberg //------------------------------------------------------------------------------ 80