xref: /libCEED/rust/libceed-sys/c-src/backends/xsmm/ceed-xsmm-tensor.c (revision d4cc18453651bd0f94c1a2e078b2646a92dafdcc)
1*9ba83ac0SJeremy L Thompson // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
24548da4eSSebastian Grimberg // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
34548da4eSSebastian Grimberg //
44548da4eSSebastian Grimberg // SPDX-License-Identifier: BSD-2-Clause
54548da4eSSebastian Grimberg //
64548da4eSSebastian Grimberg // This file is part of CEED:  http://github.com/ceed
74548da4eSSebastian Grimberg 
84548da4eSSebastian Grimberg #include <ceed.h>
94548da4eSSebastian Grimberg #include <ceed/backend.h>
104548da4eSSebastian Grimberg #include <libxsmm.h>
114548da4eSSebastian Grimberg 
124548da4eSSebastian Grimberg #include "ceed-xsmm.h"
134548da4eSSebastian Grimberg 
144548da4eSSebastian Grimberg //------------------------------------------------------------------------------
154548da4eSSebastian Grimberg // Tensor Contract Apply
164548da4eSSebastian Grimberg //------------------------------------------------------------------------------
CeedTensorContractApply_Xsmm(CeedTensorContract contract,CeedInt A,CeedInt B,CeedInt C,CeedInt J,const CeedScalar * restrict t,CeedTransposeMode t_mode,const CeedInt add,const CeedScalar * restrict u,CeedScalar * restrict v)174548da4eSSebastian Grimberg static int CeedTensorContractApply_Xsmm(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
184548da4eSSebastian Grimberg                                         CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
194548da4eSSebastian Grimberg   if (C == 1) {
204548da4eSSebastian Grimberg     // Build or query the required kernel
214548da4eSSebastian Grimberg     const int                  flags_t    = LIBXSMM_GEMM_FLAGS(!t_mode ? 'T' : 'N', 'N');
224548da4eSSebastian Grimberg     const int                  flags_ab   = (!add) ? LIBXSMM_GEMM_FLAG_BETA_0 : LIBXSMM_BASIC_GEMM_FLAG_NONE;
234548da4eSSebastian Grimberg     const int                  flags      = (flags_t | flags_ab);
244548da4eSSebastian Grimberg     const libxsmm_gemm_shape   gemm_shape = (CEED_SCALAR_TYPE == CEED_SCALAR_FP64)
254548da4eSSebastian Grimberg                                                 ? libxsmm_create_gemm_shape(J, A, B, !t_mode ? B : J, B, J, LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64,
264548da4eSSebastian Grimberg                                                                             LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64)
274548da4eSSebastian Grimberg                                                 : libxsmm_create_gemm_shape(J, A, B, !t_mode ? B : J, B, J, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32,
284548da4eSSebastian Grimberg                                                                             LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32);
29f220c67cSJeremy L Thompson     const libxsmm_gemmfunction kernel = libxsmm_dispatch_gemm(gemm_shape, (libxsmm_bitfield)(flags), (libxsmm_bitfield)LIBXSMM_GEMM_PREFETCH_NONE);
30ad70ee2cSJeremy L Thompson     libxsmm_gemm_param         gemm_param;
31ad70ee2cSJeremy L Thompson 
329bc66399SJeremy L Thompson     CeedCheck(kernel, CeedTensorContractReturnCeed(contract), CEED_ERROR_BACKEND, "LIBXSMM kernel failed to build.");
334548da4eSSebastian Grimberg 
344548da4eSSebastian Grimberg     // Run kernel
354548da4eSSebastian Grimberg     gemm_param.a.primary = (CeedScalar *)&t[0];
364548da4eSSebastian Grimberg     gemm_param.b.primary = (CeedScalar *)&u[0];
374548da4eSSebastian Grimberg     gemm_param.c.primary = (CeedScalar *)&v[0];
384548da4eSSebastian Grimberg     kernel(&gemm_param);
394548da4eSSebastian Grimberg   } else {
404548da4eSSebastian Grimberg     // Build or query the required kernel
414548da4eSSebastian Grimberg     const int                  flags_t    = LIBXSMM_GEMM_FLAGS('N', t_mode ? 'T' : 'N');
424548da4eSSebastian Grimberg     const int                  flags_ab   = (!add) ? LIBXSMM_GEMM_FLAG_BETA_0 : LIBXSMM_BASIC_GEMM_FLAG_NONE;
434548da4eSSebastian Grimberg     const int                  flags      = (flags_t | flags_ab);
444548da4eSSebastian Grimberg     const libxsmm_gemm_shape   gemm_shape = (CEED_SCALAR_TYPE == CEED_SCALAR_FP64)
454548da4eSSebastian Grimberg                                                 ? libxsmm_create_gemm_shape(C, J, B, C, !t_mode ? B : J, C, LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64,
464548da4eSSebastian Grimberg                                                                             LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64)
474548da4eSSebastian Grimberg                                                 : libxsmm_create_gemm_shape(C, J, B, C, !t_mode ? B : J, C, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32,
484548da4eSSebastian Grimberg                                                                             LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32);
49f220c67cSJeremy L Thompson     const libxsmm_gemmfunction kernel = libxsmm_dispatch_gemm(gemm_shape, (libxsmm_bitfield)(flags), (libxsmm_bitfield)LIBXSMM_GEMM_PREFETCH_NONE);
50ad70ee2cSJeremy L Thompson     libxsmm_gemm_param         gemm_param;
51ad70ee2cSJeremy L Thompson 
529bc66399SJeremy L Thompson     CeedCheck(kernel, CeedTensorContractReturnCeed(contract), CEED_ERROR_BACKEND, "LIBXSMM kernel failed to build.");
534548da4eSSebastian Grimberg 
544548da4eSSebastian Grimberg     // Run kernel
554548da4eSSebastian Grimberg     gemm_param.b.primary = (CeedScalar *)&t[0];
564548da4eSSebastian Grimberg     for (CeedInt a = 0; a < A; a++) {
574548da4eSSebastian Grimberg       gemm_param.a.primary = (CeedScalar *)&u[a * B * C];
584548da4eSSebastian Grimberg       gemm_param.c.primary = (CeedScalar *)&v[a * J * C];
594548da4eSSebastian Grimberg       kernel(&gemm_param);
604548da4eSSebastian Grimberg     }
614548da4eSSebastian Grimberg   }
624548da4eSSebastian Grimberg   return CEED_ERROR_SUCCESS;
634548da4eSSebastian Grimberg }
644548da4eSSebastian Grimberg 
654548da4eSSebastian Grimberg //------------------------------------------------------------------------------
664548da4eSSebastian Grimberg // Tensor Contract Create
674548da4eSSebastian Grimberg //------------------------------------------------------------------------------
// Register the LIBXSMM-backed "Apply" implementation on the given tensor contraction object.
int CeedTensorContractCreate_Xsmm(CeedTensorContract contract) {
  Ceed ceed = CeedTensorContractReturnCeed(contract);

  CeedCallBackend(CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply", CeedTensorContractApply_Xsmm));
  return CEED_ERROR_SUCCESS;
}
724548da4eSSebastian Grimberg 
734548da4eSSebastian Grimberg //------------------------------------------------------------------------------
74