xref: /libCEED/backends/xsmm/ceed-xsmm-tensor.c (revision c71e1dcdcfce633cdb19fa81aa6735b006eb809d)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <string.h>
18 #include <libxsmm.h>
19 #include "ceed-xsmm.h"
20 
21 // Blocked Tensor Contract
22 static int CeedTensorContract_Xsmm_Blocked(CeedTensorContract contract,
23     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
24     CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
25     CeedScalar *restrict v) {
26   CeedScalar alpha = 1.0, beta = 1.0;
27   char transu = 'N', transt = 'N';
28   if (tmode == CEED_TRANSPOSE)
29     transt = 'T';
30 
31   if (!Add)
32     beta = 0.0;
33 
34   for (CeedInt a=0; a<A; a++)
35     // libXSMM GEMM
36     libxsmm_dgemm(&transu, &transt, &C, &J, &B,
37                   &alpha, &u[a*B*C], NULL, &t[0], NULL,
38                   &beta, &v[a*J*C], NULL);
39   return 0;
40 }
41 
42 // Serial Tensor Contact
43 static int CeedTensorContract_Xsmm_Serial(CeedTensorContract contract,
44     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
45     CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
46     CeedScalar *restrict v) {
47   CeedScalar alpha = 1.0, beta = 1.0;
48   char transu = 'N', transt = 'N';
49   if ((tmode == CEED_TRANSPOSE && C != 1)
50       || (tmode == CEED_NOTRANSPOSE && C == 1))
51     transt = 'T';
52 
53   if (!Add)
54     beta = 0.0;
55 
56   if (C != 1)
57     for (CeedInt a=0; a<A; a++)
58       // libXSMM GEMM
59       libxsmm_dgemm(&transu, &transt, &C, &J, &B,
60                     &alpha, &u[a*B*C], NULL, &t[0], NULL,
61                     &beta, &v[a*J*C], NULL);
62   else
63     // libXSMM GEMM
64     libxsmm_dgemm(&transt, &transu, &J, &A, &B,
65                   &alpha, &t[0], NULL, &u[0], NULL,
66                   &beta, &v[0], NULL);
67 
68   return 0;
69 }
70 
71 // Switch for Tensor Contract
72 static int CeedTensorContractApply_Xsmm(CeedTensorContract contract, CeedInt A,
73                                         CeedInt B, CeedInt C, CeedInt J,
74                                         const CeedScalar *restrict t,
75                                         CeedTransposeMode tmode,
76                                         const CeedInt Add,
77                                         const CeedScalar *restrict u,
78                                         CeedScalar *restrict v) {
79   CeedInt blksize = 8;
80 
81   if (C % blksize)
82     CeedTensorContract_Xsmm_Serial(contract, A, B, C, J, t, tmode, Add, u, v);
83   else
84     CeedTensorContract_Xsmm_Blocked(contract, A, B, C, J, t, tmode, Add, u, v);
85 
86   return 0;
87 }
88 
89 static int CeedTensorContractDestroy_Xsmm(CeedTensorContract contract) {
90   return 0;
91 }
92 
93 int CeedTensorContractCreate_Xsmm(CeedBasis basis, CeedTensorContract contract) {
94   int ierr;
95   Ceed ceed;
96   ierr = CeedTensorContractGetCeed(contract, &ceed); CeedChk(ierr);
97 
98   ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply",
99                                 CeedTensorContractApply_Xsmm); CeedChk(ierr);
100   ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Destroy",
101                                 CeedTensorContractDestroy_Xsmm); CeedChk(ierr);
102 
103   return 0;
104 }
105