xref: /libCEED/rust/libceed-sys/c-src/backends/xsmm/ceed-xsmm-tensor.c (revision 4548da4e4ef44dc0f2704ad6d48ac0ca4a16bc83)
1*4548da4eSSebastian Grimberg // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2*4548da4eSSebastian Grimberg // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3*4548da4eSSebastian Grimberg //
4*4548da4eSSebastian Grimberg // SPDX-License-Identifier: BSD-2-Clause
5*4548da4eSSebastian Grimberg //
6*4548da4eSSebastian Grimberg // This file is part of CEED:  http://github.com/ceed
7*4548da4eSSebastian Grimberg 
8*4548da4eSSebastian Grimberg #include <ceed.h>
9*4548da4eSSebastian Grimberg #include <ceed/backend.h>
10*4548da4eSSebastian Grimberg #include <libxsmm.h>
11*4548da4eSSebastian Grimberg 
12*4548da4eSSebastian Grimberg #include "ceed-xsmm.h"
13*4548da4eSSebastian Grimberg 
14*4548da4eSSebastian Grimberg //------------------------------------------------------------------------------
15*4548da4eSSebastian Grimberg // Tensor Contract Apply
16*4548da4eSSebastian Grimberg //------------------------------------------------------------------------------
17*4548da4eSSebastian Grimberg static int CeedTensorContractApply_Xsmm(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
18*4548da4eSSebastian Grimberg                                         CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
19*4548da4eSSebastian Grimberg   Ceed ceed;
20*4548da4eSSebastian Grimberg   CeedCallBackend(CeedTensorContractGetCeed(contract, &ceed));
21*4548da4eSSebastian Grimberg 
22*4548da4eSSebastian Grimberg   if (C == 1) {
23*4548da4eSSebastian Grimberg     // Build or query the required kernel
24*4548da4eSSebastian Grimberg     const int                  flags_t    = LIBXSMM_GEMM_FLAGS(!t_mode ? 'T' : 'N', 'N');
25*4548da4eSSebastian Grimberg     const int                  flags_ab   = (!add) ? LIBXSMM_GEMM_FLAG_BETA_0 : LIBXSMM_BASIC_GEMM_FLAG_NONE;
26*4548da4eSSebastian Grimberg     const int                  flags      = (flags_t | flags_ab);
27*4548da4eSSebastian Grimberg     const libxsmm_gemm_shape   gemm_shape = (CEED_SCALAR_TYPE == CEED_SCALAR_FP64)
28*4548da4eSSebastian Grimberg                                                 ? libxsmm_create_gemm_shape(J, A, B, !t_mode ? B : J, B, J, LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64,
29*4548da4eSSebastian Grimberg                                                                             LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64)
30*4548da4eSSebastian Grimberg                                                 : libxsmm_create_gemm_shape(J, A, B, !t_mode ? B : J, B, J, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32,
31*4548da4eSSebastian Grimberg                                                                             LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32);
32*4548da4eSSebastian Grimberg     const libxsmm_gemmfunction kernel = libxsmm_dispatch_gemm_v2(gemm_shape, (libxsmm_bitfield)(flags), (libxsmm_bitfield)LIBXSMM_GEMM_PREFETCH_NONE);
33*4548da4eSSebastian Grimberg     CeedCheck(kernel, ceed, CEED_ERROR_BACKEND, "LIBXSMM kernel failed to build.");
34*4548da4eSSebastian Grimberg 
35*4548da4eSSebastian Grimberg     // Run kernel
36*4548da4eSSebastian Grimberg     libxsmm_gemm_param gemm_param;
37*4548da4eSSebastian Grimberg     gemm_param.a.primary = (CeedScalar *)&t[0];
38*4548da4eSSebastian Grimberg     gemm_param.b.primary = (CeedScalar *)&u[0];
39*4548da4eSSebastian Grimberg     gemm_param.c.primary = (CeedScalar *)&v[0];
40*4548da4eSSebastian Grimberg     kernel(&gemm_param);
41*4548da4eSSebastian Grimberg   } else {
42*4548da4eSSebastian Grimberg     // Build or query the required kernel
43*4548da4eSSebastian Grimberg     const int                  flags_t    = LIBXSMM_GEMM_FLAGS('N', t_mode ? 'T' : 'N');
44*4548da4eSSebastian Grimberg     const int                  flags_ab   = (!add) ? LIBXSMM_GEMM_FLAG_BETA_0 : LIBXSMM_BASIC_GEMM_FLAG_NONE;
45*4548da4eSSebastian Grimberg     const int                  flags      = (flags_t | flags_ab);
46*4548da4eSSebastian Grimberg     const libxsmm_gemm_shape   gemm_shape = (CEED_SCALAR_TYPE == CEED_SCALAR_FP64)
47*4548da4eSSebastian Grimberg                                                 ? libxsmm_create_gemm_shape(C, J, B, C, !t_mode ? B : J, C, LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64,
48*4548da4eSSebastian Grimberg                                                                             LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64)
49*4548da4eSSebastian Grimberg                                                 : libxsmm_create_gemm_shape(C, J, B, C, !t_mode ? B : J, C, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32,
50*4548da4eSSebastian Grimberg                                                                             LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32);
51*4548da4eSSebastian Grimberg     const libxsmm_gemmfunction kernel = libxsmm_dispatch_gemm_v2(gemm_shape, (libxsmm_bitfield)(flags), (libxsmm_bitfield)LIBXSMM_GEMM_PREFETCH_NONE);
52*4548da4eSSebastian Grimberg     CeedCheck(kernel, ceed, CEED_ERROR_BACKEND, "LIBXSMM kernel failed to build.");
53*4548da4eSSebastian Grimberg 
54*4548da4eSSebastian Grimberg     // Run kernel
55*4548da4eSSebastian Grimberg     libxsmm_gemm_param gemm_param;
56*4548da4eSSebastian Grimberg     gemm_param.b.primary = (CeedScalar *)&t[0];
57*4548da4eSSebastian Grimberg     for (CeedInt a = 0; a < A; a++) {
58*4548da4eSSebastian Grimberg       gemm_param.a.primary = (CeedScalar *)&u[a * B * C];
59*4548da4eSSebastian Grimberg       gemm_param.c.primary = (CeedScalar *)&v[a * J * C];
60*4548da4eSSebastian Grimberg       kernel(&gemm_param);
61*4548da4eSSebastian Grimberg     }
62*4548da4eSSebastian Grimberg   }
63*4548da4eSSebastian Grimberg 
64*4548da4eSSebastian Grimberg   return CEED_ERROR_SUCCESS;
65*4548da4eSSebastian Grimberg }
66*4548da4eSSebastian Grimberg 
67*4548da4eSSebastian Grimberg //------------------------------------------------------------------------------
68*4548da4eSSebastian Grimberg // Tensor Contract Create
69*4548da4eSSebastian Grimberg //------------------------------------------------------------------------------
70*4548da4eSSebastian Grimberg int CeedTensorContractCreate_Xsmm(CeedBasis basis, CeedTensorContract contract) {
71*4548da4eSSebastian Grimberg   Ceed ceed;
72*4548da4eSSebastian Grimberg   CeedCallBackend(CeedTensorContractGetCeed(contract, &ceed));
73*4548da4eSSebastian Grimberg 
74*4548da4eSSebastian Grimberg   CeedCallBackend(CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply", CeedTensorContractApply_Xsmm));
75*4548da4eSSebastian Grimberg 
76*4548da4eSSebastian Grimberg   return CEED_ERROR_SUCCESS;
77*4548da4eSSebastian Grimberg }
78*4548da4eSSebastian Grimberg 
79*4548da4eSSebastian Grimberg //------------------------------------------------------------------------------
80