xref: /libCEED/backends/xsmm/ceed-xsmm-tensor.c (revision b3d368a0e19d4c2fd0dbe42e8214dac7e42dd106)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include "ceed-xsmm.h"
18 
19 //------------------------------------------------------------------------------
20 // Get Kernel Index
21 //------------------------------------------------------------------------------
22 int CeedGetXsmmInd_Tensor(CeedInt nelem, CeedInt add, CeedTransposeMode tmode,
23                           CeedInt B, CeedInt C, CeedInt J, CeedInt currdim,
24                           CeedInt dim) {
25   return (nelem == 8 ? 1:0)*4*2*dim + (add ? 1:0)*4*dim +
26          (tmode ? 1:0)*2*dim + (B == J ? 1:0)*dim + currdim;
27 }
28 
29 int CeedGetXsmmInd_NonTensor(CeedInt add, CeedInt P, CeedInt Q, CeedInt B,
30                              CeedInt C, CeedInt J) {
31   return (C == 8 ? 1:0)*4*2 + (add ? 1:0)*4 +
32          (B == P ? (J == Q ? 0:1) : (B == Q ? 2:3));
33 }
34 
35 //------------------------------------------------------------------------------
36 // Tensor Contract C=1
37 //------------------------------------------------------------------------------
38 static int CeedTensorContract_Xsmm_C1(CeedTensorContract contract,
39                                       CeedInt A, CeedInt B, CeedInt C,
40                                       CeedInt J, const CeedScalar *restrict t,
41                                       CeedTransposeMode tmode,
42                                       const CeedInt Add,
43                                       const CeedScalar *restrict u,
44                                       CeedScalar *restrict v) {
45   CeedScalar alpha = 1.0, beta = 1.0;
46   char transu = 'N', transt = 'N';
47   if ((tmode == CEED_TRANSPOSE && C != 1)
48       || (tmode == CEED_NOTRANSPOSE && C == 1))
49     transt = 'T';
50 
51   if (!Add)
52     beta = 0.0;
53 
54   // libXSMM GEMM
55   libxsmm_dgemm(&transt, &transu, &J, &A, &B,
56                 &alpha, &t[0], NULL, &u[0], NULL,
57                 &beta, &v[0], NULL);
58 
59   return 0;
60 }
61 
62 //------------------------------------------------------------------------------
63 // Tensor Contract Apply
64 //------------------------------------------------------------------------------
65 static int CeedTensorContractApply_Xsmm(CeedTensorContract contract, CeedInt A,
66                                         CeedInt B, CeedInt C, CeedInt J,
67                                         const CeedScalar *restrict t,
68                                         CeedTransposeMode tmode,
69                                         const CeedInt add,
70                                         const CeedScalar *restrict u,
71                                         CeedScalar *restrict v) {
72   int ierr;
73   CeedInt blksize = 8, ind, nelem;
74   CeedTensorContract_Xsmm *impl;
75   ierr = CeedTensorContractGetData(contract, (void *)&impl); CeedChk(ierr);
76 
77   // Get nelem and current dim
78   CeedScalar currdim = log(C/blksize) / log(J);
79   if (!(C % blksize) && currdim - (int)currdim < 1e-15) {
80     nelem = blksize;
81   } else {
82     nelem = 1;
83     currdim = log(C) / log(J);
84   }
85 
86   // Get kernel index
87   if (impl->tensorbasis)
88     ind = CeedGetXsmmInd_Tensor(nelem, add, tmode==CEED_TRANSPOSE?1:0, B, C, J,
89                                 (CeedInt)currdim, impl->dim);
90   else
91     ind = CeedGetXsmmInd_NonTensor(add, impl->P, impl->Q, B, C, J);
92 
93   // Run kernel or fallback to default implementation
94   if (C != 1)
95     for (CeedInt a=0; a<A; a++)
96       impl->kernels[ind](&u[a*B*C], &t[0], &v[a*J*C], NULL, NULL, NULL);
97   else
98     CeedTensorContract_Xsmm_C1(contract, A, B, C, J, t, tmode, add, u, v);
99 
100   return 0;
101 }
102 
103 //------------------------------------------------------------------------------
104 // Tensor Contract Destroy
105 //------------------------------------------------------------------------------
106 static int CeedTensorContractDestroy_Xsmm(CeedTensorContract contract) {
107   int ierr;
108   CeedTensorContract_Xsmm *impl;
109   ierr = CeedTensorContractGetData(contract, (void *)&impl); CeedChk(ierr);
110   // Free kernels
111   if (impl->tensorbasis) {
112     for (CeedInt nelem = 1; nelem <= 8; nelem+=7)
113       for (CeedInt add = 0; add <= 1; add++)
114         for (CeedInt tmode = 0; tmode <= 1; tmode++)
115           for (CeedInt grad = 0; grad <=1; grad++)
116             for (CeedInt dim = 0; dim < impl->dim; dim++) {
117               CeedInt B = grad ? impl->Q : (tmode ? impl->Q : impl->P),
118                       J = grad ? impl->Q : (tmode ? impl->P : impl->Q),
119                       C = nelem*CeedIntPow(J, dim);
120               int ind = CeedGetXsmmInd_Tensor(nelem, add, tmode, B, C, J, dim,
121                                               impl->dim);
122               libxsmm_release_kernel(&impl->kernels[ind]);
123             }
124   } else {
125     for (CeedInt nelem = 1; nelem <= 8; nelem+=7)
126       for (CeedInt add = 0; add <= 1; add++)
127         for (CeedInt tmode = 0; tmode <= 1; tmode++)
128           for (CeedInt grad = 1; grad <= impl->dim; grad+=impl->dim-1) {
129             CeedInt B = tmode ? grad*impl->Q : impl->P,
130                     J = tmode ? impl->P : grad*impl->Q;
131             int ind = CeedGetXsmmInd_NonTensor(add, impl->P, impl->Q, B, nelem,
132                                                J);
133             libxsmm_release_kernel(&impl->kernels[ind]);
134           }
135   }
136   ierr = CeedFree(&impl->kernels); CeedChk(ierr);
137   ierr = CeedFree(&impl); CeedChk(ierr);
138 
139   return 0;
140 }
141 
142 //------------------------------------------------------------------------------
143 // Tensor Contract Create
144 //------------------------------------------------------------------------------
145 int CeedTensorContractCreate_Xsmm(CeedBasis basis,
146                                   CeedTensorContract contract) {
147   int ierr;
148   Ceed ceed;
149   ierr = CeedTensorContractGetCeed(contract, &ceed); CeedChk(ierr);
150   CeedTensorContract_Xsmm *impl;
151   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
152 
153   // Set up pointers to kernels
154   ierr = CeedBasisGetTensorStatus(basis, &impl->tensorbasis); CeedChk(ierr);
155   if (impl->tensorbasis) {
156     ierr = CeedBasisGetNumNodes1D(basis, &impl->P); CeedChk(ierr);
157     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &impl->Q); CeedChk(ierr);
158     ierr = CeedBasisGetDimension(basis, &impl->dim); CeedChk(ierr);
159     // Set up kernel pointer array
160     impl->numkernels = 2*2*4*impl->dim;
161     ierr = CeedCalloc(impl->numkernels, &impl->kernels); CeedChk(ierr);
162     for (CeedInt nelem = 1; nelem <= 8; nelem+=7)
163       for (CeedInt add = 0; add <= 1; add++)
164         for (CeedInt tmode = 0; tmode <= 1; tmode++)
165           for (CeedInt grad = 0; grad <=1; grad++)
166             for (CeedInt dim = 0; dim < impl->dim; dim++) {
167               const int flags = LIBXSMM_GEMM_FLAGS('N', tmode ? 'T' : 'N');
168               CeedInt B = grad ? impl->Q : (tmode ? impl->Q : impl->P),
169                       J = grad ? impl->Q : (tmode ? impl->P : impl->Q),
170                       C = nelem*CeedIntPow(J, dim);
171               int ind = CeedGetXsmmInd_Tensor(nelem, add, tmode, B, C, J, dim,
172                                               impl->dim);
173               CeedScalar alpha = 1.0, beta = 1.0;
174               if (!add) beta = 0.0;
175               impl->kernels[ind] = libxsmm_dmmdispatch(C, J, B,
176                                    NULL, NULL, NULL, &alpha,
177                                    &beta, &flags, NULL);
178               if (!impl->kernels[ind])
179                 // LCOV_EXCL_START
180                 return CeedError(ceed, 1, "LIBXSMM kernel failed to build.");
181               // LCOV_EXCL_STOP
182             }
183   } else {
184     ierr = CeedBasisGetNumNodes(basis, &impl->P); CeedChk(ierr);
185     ierr = CeedBasisGetNumQuadraturePoints(basis, &impl->Q); CeedChk(ierr);
186     ierr = CeedBasisGetDimension(basis, &impl->dim); CeedChk(ierr);
187     // Set up kernel pointer array
188     impl->numkernels = 4*2*2;
189     ierr = CeedCalloc(impl->numkernels, &impl->kernels); CeedChk(ierr);
190     for (CeedInt nelem = 1; nelem <= 8; nelem+=7)
191       for (CeedInt add = 0; add <= 1; add++)
192         for (CeedInt tmode = 0; tmode <= 1; tmode++)
193           for (CeedInt grad = 1; grad <= impl->dim; grad+=impl->dim-1) {
194             const int flags = LIBXSMM_GEMM_FLAGS('N', tmode ? 'T' : 'N');
195             CeedInt B = tmode ? grad*impl->Q : impl->P,
196                     J = tmode ? impl->P : grad*impl->Q;
197             int ind = CeedGetXsmmInd_NonTensor(add, impl->P, impl->Q, B, nelem,
198                                                J);
199             CeedScalar alpha = 1.0, beta = 1.0;
200             if (!add) beta = 0.0;
201             impl->kernels[ind] = libxsmm_dmmdispatch(nelem, J, B,
202                                  NULL, NULL, NULL, &alpha,
203                                  &beta, &flags, NULL);
204             if (!impl->kernels[ind])
205               // LCOV_EXCL_START
206               return CeedError(ceed, 1, "LIBXSMM kernel failed to build.");
207             // LCOV_EXCL_STOP
208           }
209   }
210   ierr = CeedTensorContractSetData(contract, (void *)&impl); CeedChk(ierr);
211 
212   ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply",
213                                 CeedTensorContractApply_Xsmm); CeedChk(ierr);
214   ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Destroy",
215                                 CeedTensorContractDestroy_Xsmm); CeedChk(ierr);
216 
217   return 0;
218 }
219 //------------------------------------------------------------------------------
220