xref: /libCEED/backends/xsmm/ceed-xsmm-tensor.c (revision 89c6efa49859cd40e0a2f99b4a91b922f3b4c9b2)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include "ceed-xsmm.h"
18 
19 // Utility functions for index in pointer array
20 int CeedGetXsmmInd_Tensor(CeedInt nelem, CeedInt add, CeedTransposeMode tmode,
21                           CeedInt B, CeedInt C, CeedInt J, CeedInt currdim,
22                           CeedInt dim) {
23   return (nelem == 8 ? 1:0)*4*2*dim + (add ? 1:0)*4*dim +
24          (tmode ? 1:0)*2*dim + (B == J ? 1:0)*dim + currdim;
25 }
26 
27 int CeedGetXsmmInd_NonTensor(CeedInt add, CeedInt P, CeedInt Q, CeedInt B,
28                              CeedInt C, CeedInt J) {
29   return (C == 8 ? 1:0)*4*2 + (add ? 1:0)*4 +
30          (B == P ? (J == Q ? 0:1) : (B == Q ? 2:3));
31 }
32 
33 // Default Tensor Contact
34 static int CeedTensorContract_Xsmm_Default(CeedTensorContract contract,
35     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
36     CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
37     CeedScalar *restrict v) {
38   CeedScalar alpha = 1.0, beta = 1.0;
39   char transu = 'N', transt = 'N';
40   if ((tmode == CEED_TRANSPOSE && C != 1)
41       || (tmode == CEED_NOTRANSPOSE && C == 1))
42     transt = 'T';
43 
44   if (!Add)
45     beta = 0.0;
46 
47   if (C != 1)
48     for (CeedInt a=0; a<A; a++)
49       // libXSMM GEMM
50       libxsmm_dgemm(&transu, &transt, &C, &J, &B,
51                     &alpha, &u[a*B*C], NULL, &t[0], NULL,
52                     &beta, &v[a*J*C], NULL);
53   else
54     // libXSMM GEMM
55     libxsmm_dgemm(&transt, &transu, &J, &A, &B,
56                   &alpha, &t[0], NULL, &u[0], NULL,
57                   &beta, &v[0], NULL);
58 
59   return 0;
60 }
61 
62 // Switch for Tensor Contract
63 static int CeedTensorContractApply_Xsmm(CeedTensorContract contract, CeedInt A,
64                                         CeedInt B, CeedInt C, CeedInt J,
65                                         const CeedScalar *restrict t,
66                                         CeedTransposeMode tmode,
67                                         const CeedInt add,
68                                         const CeedScalar *restrict u,
69                                         CeedScalar *restrict v) {
70   int ierr;
71   CeedInt blksize = 8, ind, nelem;
72   CeedTensorContract_Xsmm *impl;
73   ierr = CeedTensorContractGetData(contract, (void *)&impl); CeedChk(ierr);
74 
75   // Get nelem and current dim
76   CeedScalar currdim = log(C/blksize) / log(J);
77   if (!(C % blksize) && currdim - (int)currdim < 1e-15)
78     nelem = blksize;
79   else {
80     nelem = 1;
81     currdim = log(C) / log(J);
82   }
83 
84   // Get kernel index
85   if (impl->tensorbasis)
86     ind = CeedGetXsmmInd_Tensor(nelem, add, tmode==CEED_TRANSPOSE?1:0, B, C,
87                                 J, (CeedInt)currdim, impl->dim);
88   else
89     ind = CeedGetXsmmInd_NonTensor(add, impl->P, impl->Q, B, C, J);
90 
91   // Run kernel or fallback to default implementation
92   if (C != 1 && impl->kernels[ind])
93     for (CeedInt a=0; a<A; a++)
94       impl->kernels[ind](&u[a*B*C], &t[0], &v[a*J*C], NULL, NULL, NULL);
95   else
96     CeedTensorContract_Xsmm_Default(contract, A, B, C, J, t, tmode, add, u, v);
97 
98   return 0;
99 }
100 
101 static int CeedTensorContractDestroy_Xsmm(CeedTensorContract contract) {
102   int ierr;
103   CeedTensorContract_Xsmm *impl;
104   ierr = CeedTensorContractGetData(contract, (void *)&impl); CeedChk(ierr);
105   ierr = CeedFree(&impl->kernels); CeedChk(ierr);
106   ierr = CeedFree(&impl); CeedChk(ierr);
107 
108   return 0;
109 }
110 
111 int CeedTensorContractCreate_Xsmm(CeedBasis basis,
112                                   CeedTensorContract contract) {
113   int ierr;
114   Ceed ceed;
115   ierr = CeedTensorContractGetCeed(contract, &ceed); CeedChk(ierr);
116   CeedTensorContract_Xsmm *impl;
117   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
118 
119   // Set up pointers to kernels
120   ierr = CeedBasisGetTensorStatus(basis, &impl->tensorbasis); CeedChk(ierr);
121   if (impl->tensorbasis) {
122     ierr = CeedBasisGetNumNodes1D(basis, &impl->P); CeedChk(ierr);
123     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &impl->Q); CeedChk(ierr);
124     ierr = CeedBasisGetDimension(basis, &impl->dim); CeedChk(ierr);
125     // Set up kernel pointer array
126     impl->numkernels = 2*2*4*impl->dim;
127     ierr = CeedCalloc(impl->numkernels, &impl->kernels); CeedChk(ierr);
128     for (CeedInt nelem = 1; nelem <= 8; nelem+=7) {
129       for (CeedInt add = 0; add <= 1; add++) {
130         for (CeedInt tmode = 0; tmode <= 1; tmode++) {
131           for (CeedInt grad = 0; grad <=1; grad++) {
132             for (CeedInt dim = 0; dim < impl->dim; dim++) {
133               const int flags = LIBXSMM_GEMM_FLAGS('N', tmode ? 'T' : 'N');
134               CeedInt B = grad ? impl->Q : (tmode ? impl->Q : impl->P),
135                       J = grad ? impl->Q : (tmode ? impl->P : impl->Q),
136                       C = nelem*CeedIntPow(J, dim);
137               int ind = CeedGetXsmmInd_Tensor(nelem, add, tmode, B, C, J, dim,
138                                               impl->dim);
139               CeedScalar alpha = 1.0, beta = 1.0;
140               if (!add) beta = 0.0;
141               impl->kernels[ind] = libxsmm_dmmdispatch(C, J, B,
142                                    NULL, NULL, NULL, &alpha,
143                                    &beta, &flags, NULL);
144               if (!impl->kernels[ind])
145                 return CeedError(ceed, 1,
146                                  "LIBXSMM kernel failed to build.");
147             }
148           }
149         }
150       }
151     }
152   } else {
153     ierr = CeedBasisGetNumNodes(basis, &impl->P); CeedChk(ierr);
154     ierr = CeedBasisGetNumQuadraturePoints(basis, &impl->Q); CeedChk(ierr);
155     ierr = CeedBasisGetDimension(basis, &impl->dim); CeedChk(ierr);
156     // Set up kernel pointer array
157     impl->numkernels = 4*2*2;
158     ierr = CeedCalloc(impl->numkernels, &impl->kernels); CeedChk(ierr);
159     for (CeedInt nelem = 1; nelem <= 8; nelem+=7) {
160       for (CeedInt add = 0; add <= 1; add++) {
161         for (CeedInt tmode = 0; tmode <= 1; tmode++) {
162           for (CeedInt grad = 1; grad <= impl->dim; grad+=impl->dim-1) {
163             const int flags = LIBXSMM_GEMM_FLAGS('N', tmode ? 'T' : 'N');
164             CeedInt B = tmode ? grad*impl->Q : impl->P,
165                     J = tmode ? impl->P : grad*impl->Q;
166             int ind = CeedGetXsmmInd_NonTensor(add, impl->P, impl->Q, B, nelem,
167                                                J);
168             CeedScalar alpha = 1.0, beta = 1.0;
169             if (!add) beta = 0.0;
170             impl->kernels[ind] = libxsmm_dmmdispatch(nelem, J, B,
171                                  NULL, NULL, NULL, &alpha,
172                                  &beta, &flags, NULL);
173             if (!impl->kernels[ind])
174               return CeedError(ceed, 1,
175                                "LIBXSMM kernel failed to build.");
176           }
177         }
178       }
179     }
180   }
181   ierr = CeedTensorContractSetData(contract, (void *)&impl); CeedChk(ierr);
182 
183   ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply",
184                                 CeedTensorContractApply_Xsmm); CeedChk(ierr);
185   ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Destroy",
186                                 CeedTensorContractDestroy_Xsmm); CeedChk(ierr);
187 
188   return 0;
189 }
190