xref: /libCEED/backends/xsmm/ceed-xsmm-tensor.c (revision cb32e2e7f026784d97a57f1901677e9727def907)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include "ceed-xsmm.h"
18 
19 // Utility functions for index in pointer array
20 int CeedGetXsmmInd_Tensor(CeedInt nelem, CeedInt add, CeedTransposeMode tmode,
21                           CeedInt B, CeedInt C, CeedInt J, CeedInt currdim,
22                           CeedInt dim) {
23   return (nelem == 8 ? 1:0)*4*2*dim + (add ? 1:0)*4*dim +
24          (tmode ? 1:0)*2*dim + (B == J ? 1:0)*dim + currdim;
25 }
26 
27 int CeedGetXsmmInd_NonTensor(CeedInt add, CeedInt P, CeedInt Q, CeedInt B,
28                              CeedInt C, CeedInt J) {
29   return (C == 8 ? 1:0)*4*2 + (add ? 1:0)*4 +
30          (B == P ? (J == Q ? 0:1) : (B == Q ? 2:3));
31 }
32 
33 // Default Tensor Contact
34 static int CeedTensorContract_Xsmm_C1(CeedTensorContract contract,
35                                       CeedInt A, CeedInt B, CeedInt C,
36                                       CeedInt J, const CeedScalar *restrict t,
37                                       CeedTransposeMode tmode,
38                                       const CeedInt Add,
39                                       const CeedScalar *restrict u,
40                                       CeedScalar *restrict v) {
41   CeedScalar alpha = 1.0, beta = 1.0;
42   char transu = 'N', transt = 'N';
43   if ((tmode == CEED_TRANSPOSE && C != 1)
44       || (tmode == CEED_NOTRANSPOSE && C == 1))
45     transt = 'T';
46 
47   if (!Add)
48     beta = 0.0;
49 
50   // libXSMM GEMM
51   libxsmm_dgemm(&transt, &transu, &J, &A, &B,
52                 &alpha, &t[0], NULL, &u[0], NULL,
53                 &beta, &v[0], NULL);
54 
55   return 0;
56 }
57 
58 // Switch for Tensor Contract
59 static int CeedTensorContractApply_Xsmm(CeedTensorContract contract, CeedInt A,
60                                         CeedInt B, CeedInt C, CeedInt J,
61                                         const CeedScalar *restrict t,
62                                         CeedTransposeMode tmode,
63                                         const CeedInt add,
64                                         const CeedScalar *restrict u,
65                                         CeedScalar *restrict v) {
66   int ierr;
67   CeedInt blksize = 8, ind, nelem;
68   CeedTensorContract_Xsmm *impl;
69   ierr = CeedTensorContractGetData(contract, (void *)&impl); CeedChk(ierr);
70 
71   // Get nelem and current dim
72   CeedScalar currdim = log(C/blksize) / log(J);
73   if (!(C % blksize) && currdim - (int)currdim < 1e-15) {
74     nelem = blksize;
75   } else {
76     nelem = 1;
77     currdim = log(C) / log(J);
78   }
79 
80   // Get kernel index
81   if (impl->tensorbasis)
82     ind = CeedGetXsmmInd_Tensor(nelem, add, tmode==CEED_TRANSPOSE?1:0, B, C, J,
83                                 (CeedInt)currdim, impl->dim);
84   else
85     ind = CeedGetXsmmInd_NonTensor(add, impl->P, impl->Q, B, C, J);
86 
87   // Run kernel or fallback to default implementation
88   if (C != 1)
89     for (CeedInt a=0; a<A; a++)
90       impl->kernels[ind](&u[a*B*C], &t[0], &v[a*J*C], NULL, NULL, NULL);
91   else
92     CeedTensorContract_Xsmm_C1(contract, A, B, C, J, t, tmode, add, u, v);
93 
94   return 0;
95 }
96 
97 static int CeedTensorContractDestroy_Xsmm(CeedTensorContract contract) {
98   int ierr;
99   CeedTensorContract_Xsmm *impl;
100   ierr = CeedTensorContractGetData(contract, (void *)&impl); CeedChk(ierr);
101   ierr = CeedFree(&impl->kernels); CeedChk(ierr);
102   ierr = CeedFree(&impl); CeedChk(ierr);
103 
104   return 0;
105 }
106 
107 int CeedTensorContractCreate_Xsmm(CeedBasis basis,
108                                   CeedTensorContract contract) {
109   int ierr;
110   Ceed ceed;
111   ierr = CeedTensorContractGetCeed(contract, &ceed); CeedChk(ierr);
112   CeedTensorContract_Xsmm *impl;
113   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
114 
115   // Set up pointers to kernels
116   ierr = CeedBasisGetTensorStatus(basis, &impl->tensorbasis); CeedChk(ierr);
117   if (impl->tensorbasis) {
118     ierr = CeedBasisGetNumNodes1D(basis, &impl->P); CeedChk(ierr);
119     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &impl->Q); CeedChk(ierr);
120     ierr = CeedBasisGetDimension(basis, &impl->dim); CeedChk(ierr);
121     // Set up kernel pointer array
122     impl->numkernels = 2*2*4*impl->dim;
123     ierr = CeedCalloc(impl->numkernels, &impl->kernels); CeedChk(ierr);
124     for (CeedInt nelem = 1; nelem <= 8; nelem+=7)
125       for (CeedInt add = 0; add <= 1; add++)
126         for (CeedInt tmode = 0; tmode <= 1; tmode++)
127           for (CeedInt grad = 0; grad <=1; grad++)
128             for (CeedInt dim = 0; dim < impl->dim; dim++) {
129               const int flags = LIBXSMM_GEMM_FLAGS('N', tmode ? 'T' : 'N');
130               CeedInt B = grad ? impl->Q : (tmode ? impl->Q : impl->P),
131                       J = grad ? impl->Q : (tmode ? impl->P : impl->Q),
132                       C = nelem*CeedIntPow(J, dim);
133               int ind = CeedGetXsmmInd_Tensor(nelem, add, tmode, B, C, J, dim,
134                                               impl->dim);
135               CeedScalar alpha = 1.0, beta = 1.0;
136               if (!add) beta = 0.0;
137               impl->kernels[ind] = libxsmm_dmmdispatch(C, J, B,
138                                    NULL, NULL, NULL, &alpha,
139                                    &beta, &flags, NULL);
140               if (!impl->kernels[ind])
141                 // LCOV_EXCL_START
142                 return CeedError(ceed, 1, "LIBXSMM kernel failed to build.");
143               // LCOV_EXCL_STOP
144             }
145   } else {
146     ierr = CeedBasisGetNumNodes(basis, &impl->P); CeedChk(ierr);
147     ierr = CeedBasisGetNumQuadraturePoints(basis, &impl->Q); CeedChk(ierr);
148     ierr = CeedBasisGetDimension(basis, &impl->dim); CeedChk(ierr);
149     // Set up kernel pointer array
150     impl->numkernels = 4*2*2;
151     ierr = CeedCalloc(impl->numkernels, &impl->kernels); CeedChk(ierr);
152     for (CeedInt nelem = 1; nelem <= 8; nelem+=7)
153       for (CeedInt add = 0; add <= 1; add++)
154         for (CeedInt tmode = 0; tmode <= 1; tmode++)
155           for (CeedInt grad = 1; grad <= impl->dim; grad+=impl->dim-1) {
156             const int flags = LIBXSMM_GEMM_FLAGS('N', tmode ? 'T' : 'N');
157             CeedInt B = tmode ? grad*impl->Q : impl->P,
158                     J = tmode ? impl->P : grad*impl->Q;
159             int ind = CeedGetXsmmInd_NonTensor(add, impl->P, impl->Q, B, nelem,
160                                                J);
161             CeedScalar alpha = 1.0, beta = 1.0;
162             if (!add) beta = 0.0;
163             impl->kernels[ind] = libxsmm_dmmdispatch(nelem, J, B,
164                                  NULL, NULL, NULL, &alpha,
165                                  &beta, &flags, NULL);
166             if (!impl->kernels[ind])
167               // LCOV_EXCL_START
168               return CeedError(ceed, 1, "LIBXSMM kernel failed to build.");
169             // LCOV_EXCL_STOP
170           }
171   }
172   ierr = CeedTensorContractSetData(contract, (void *)&impl); CeedChk(ierr);
173 
174   ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply",
175                                 CeedTensorContractApply_Xsmm); CeedChk(ierr);
176   ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Destroy",
177                                 CeedTensorContractDestroy_Xsmm); CeedChk(ierr);
178 
179   return 0;
180 }
181