1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED: http://github.com/ceed
7
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <libxsmm.h>
11
12 #include "ceed-xsmm.h"
13
14 //------------------------------------------------------------------------------
15 // Tensor Contract Apply
16 //------------------------------------------------------------------------------
// Apply a tensor contraction by dispatching (building or looking up) a LIBXSMM GEMM kernel and running it.
//
//   contract   - tensor contraction object; used here only to obtain the Ceed for error reporting
//   A, B, C, J - extents of the contraction; u is read as A slices of B * C entries, v is written as A slices of J * C entries
//   t          - matrix operand; passed to LIBXSMM with leading dimension B when t_mode is zero and J otherwise
//   t_mode     - transpose mode of t; selects the transpose flag handed to LIBXSMM
//   add        - zero: kernel is requested with LIBXSMM_GEMM_FLAG_BETA_0 (v is overwritten);
//                nonzero: kernel is requested without it (result is accumulated into v)
//   u          - input array
//   v          - output array
//
// Returns CEED_ERROR_SUCCESS; CeedCheck reports CEED_ERROR_BACKEND if LIBXSMM does not return a kernel.
static int CeedTensorContractApply_Xsmm(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
                                        CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
  if (C == 1) {
    // C == 1: the whole contraction is a single GEMM with t as the first operand and u as the second
    // Build or query the required kernel
    const int                flags_t    = LIBXSMM_GEMM_FLAGS(!t_mode ? 'T' : 'N', 'N');
    const int                flags_ab   = (!add) ? LIBXSMM_GEMM_FLAG_BETA_0 : LIBXSMM_BASIC_GEMM_FLAG_NONE;
    const int                flags      = (flags_t | flags_ab);
    const libxsmm_gemm_shape gemm_shape = (CEED_SCALAR_TYPE == CEED_SCALAR_FP64)
                                              ? libxsmm_create_gemm_shape(J, A, B, !t_mode ? B : J, B, J, LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64,
                                                                          LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64)
                                              : libxsmm_create_gemm_shape(J, A, B, !t_mode ? B : J, B, J, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32,
                                                                          LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32);
    const libxsmm_gemmfunction kernel = libxsmm_dispatch_gemm(gemm_shape, (libxsmm_bitfield)(flags), (libxsmm_bitfield)LIBXSMM_GEMM_PREFETCH_NONE);
    // Zero-initialize so that fields not set below never reach the kernel with indeterminate values
    libxsmm_gemm_param gemm_param = {0};

    CeedCheck(kernel, CeedTensorContractReturnCeed(contract), CEED_ERROR_BACKEND, "LIBXSMM kernel failed to build.");

    // Run kernel
    gemm_param.a.primary = (CeedScalar *)&t[0];
    gemm_param.b.primary = (CeedScalar *)&u[0];
    gemm_param.c.primary = (CeedScalar *)&v[0];
    kernel(&gemm_param);
  } else {
    // C != 1: one GEMM per slice a, with the u slice as the first operand and t as the second
    // Build or query the required kernel
    const int                flags_t    = LIBXSMM_GEMM_FLAGS('N', t_mode ? 'T' : 'N');
    const int                flags_ab   = (!add) ? LIBXSMM_GEMM_FLAG_BETA_0 : LIBXSMM_BASIC_GEMM_FLAG_NONE;
    const int                flags      = (flags_t | flags_ab);
    const libxsmm_gemm_shape gemm_shape = (CEED_SCALAR_TYPE == CEED_SCALAR_FP64)
                                              ? libxsmm_create_gemm_shape(C, J, B, C, !t_mode ? B : J, C, LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64,
                                                                          LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64)
                                              : libxsmm_create_gemm_shape(C, J, B, C, !t_mode ? B : J, C, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32,
                                                                          LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32);
    const libxsmm_gemmfunction kernel = libxsmm_dispatch_gemm(gemm_shape, (libxsmm_bitfield)(flags), (libxsmm_bitfield)LIBXSMM_GEMM_PREFETCH_NONE);
    // Zero-initialize so that fields not set below never reach the kernel with indeterminate values
    libxsmm_gemm_param gemm_param = {0};

    CeedCheck(kernel, CeedTensorContractReturnCeed(contract), CEED_ERROR_BACKEND, "LIBXSMM kernel failed to build.");

    // Run kernel
    // t is shared by every slice, so it is bound once outside the loop
    gemm_param.b.primary = (CeedScalar *)&t[0];
    for (CeedInt a = 0; a < A; a++) {
      gemm_param.a.primary = (CeedScalar *)&u[a * B * C];
      gemm_param.c.primary = (CeedScalar *)&v[a * J * C];
      kernel(&gemm_param);
    }
  }
  return CEED_ERROR_SUCCESS;
}
64
65 //------------------------------------------------------------------------------
66 // Tensor Contract Create
67 //------------------------------------------------------------------------------
// Set up a tensor contraction object for this backend by registering the LIBXSMM-based Apply routine on it.
// Returns CEED_ERROR_SUCCESS; a failing registration is returned to the caller through CeedCallBackend.
int CeedTensorContractCreate_Xsmm(CeedTensorContract contract) {
  // Ceed handle the contraction object reports; the backend function is registered against it
  Ceed ceed = CeedTensorContractReturnCeed(contract);

  CeedCallBackend(CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply", CeedTensorContractApply_Xsmm));
  return CEED_ERROR_SUCCESS;
}
72
73 //------------------------------------------------------------------------------
74