xref: /libCEED/backends/xsmm/ceed-xsmm-tensor.c (revision 7113573b6efd54558bb98b919dff5d6d8ffcff54)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <libxsmm.h>
11 
12 #include "ceed-xsmm.h"
13 
14 //------------------------------------------------------------------------------
15 // Tensor Contract Apply
16 //------------------------------------------------------------------------------
17 static int CeedTensorContractApply_Xsmm(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
18                                         CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
19   Ceed ceed;
20   CeedCallBackend(CeedTensorContractGetCeed(contract, &ceed));
21 
22   if (C == 1) {
23     // Build or query the required kernel
24     const int                  flags_t    = LIBXSMM_GEMM_FLAGS(!t_mode ? 'T' : 'N', 'N');
25     const int                  flags_ab   = (!add) ? LIBXSMM_GEMM_FLAG_BETA_0 : LIBXSMM_BASIC_GEMM_FLAG_NONE;
26     const int                  flags      = (flags_t | flags_ab);
27     const libxsmm_gemm_shape   gemm_shape = (CEED_SCALAR_TYPE == CEED_SCALAR_FP64)
28                                                 ? libxsmm_create_gemm_shape(J, A, B, !t_mode ? B : J, B, J, LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64,
29                                                                             LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64)
30                                                 : libxsmm_create_gemm_shape(J, A, B, !t_mode ? B : J, B, J, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32,
31                                                                             LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32);
32     const libxsmm_gemmfunction kernel = libxsmm_dispatch_gemm_v2(gemm_shape, (libxsmm_bitfield)(flags), (libxsmm_bitfield)LIBXSMM_GEMM_PREFETCH_NONE);
33     CeedCheck(kernel, ceed, CEED_ERROR_BACKEND, "LIBXSMM kernel failed to build.");
34 
35     // Run kernel
36     libxsmm_gemm_param gemm_param;
37     gemm_param.a.primary = (CeedScalar *)&t[0];
38     gemm_param.b.primary = (CeedScalar *)&u[0];
39     gemm_param.c.primary = (CeedScalar *)&v[0];
40     kernel(&gemm_param);
41   } else {
42     // Build or query the required kernel
43     const int                  flags_t    = LIBXSMM_GEMM_FLAGS('N', t_mode ? 'T' : 'N');
44     const int                  flags_ab   = (!add) ? LIBXSMM_GEMM_FLAG_BETA_0 : LIBXSMM_BASIC_GEMM_FLAG_NONE;
45     const int                  flags      = (flags_t | flags_ab);
46     const libxsmm_gemm_shape   gemm_shape = (CEED_SCALAR_TYPE == CEED_SCALAR_FP64)
47                                                 ? libxsmm_create_gemm_shape(C, J, B, C, !t_mode ? B : J, C, LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64,
48                                                                             LIBXSMM_DATATYPE_F64, LIBXSMM_DATATYPE_F64)
49                                                 : libxsmm_create_gemm_shape(C, J, B, C, !t_mode ? B : J, C, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32,
50                                                                             LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32);
51     const libxsmm_gemmfunction kernel = libxsmm_dispatch_gemm_v2(gemm_shape, (libxsmm_bitfield)(flags), (libxsmm_bitfield)LIBXSMM_GEMM_PREFETCH_NONE);
52     CeedCheck(kernel, ceed, CEED_ERROR_BACKEND, "LIBXSMM kernel failed to build.");
53 
54     // Run kernel
55     libxsmm_gemm_param gemm_param;
56     gemm_param.b.primary = (CeedScalar *)&t[0];
57     for (CeedInt a = 0; a < A; a++) {
58       gemm_param.a.primary = (CeedScalar *)&u[a * B * C];
59       gemm_param.c.primary = (CeedScalar *)&v[a * J * C];
60       kernel(&gemm_param);
61     }
62   }
63 
64   return CEED_ERROR_SUCCESS;
65 }
66 
67 //------------------------------------------------------------------------------
68 // Tensor Contract Create
69 //------------------------------------------------------------------------------
70 int CeedTensorContractCreate_Xsmm(CeedBasis basis, CeedTensorContract contract) {
71   Ceed ceed;
72   CeedCallBackend(CeedTensorContractGetCeed(contract, &ceed));
73 
74   CeedCallBackend(CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply", CeedTensorContractApply_Xsmm));
75 
76   return CEED_ERROR_SUCCESS;
77 }
78 
79 //------------------------------------------------------------------------------
80