xref: /libCEED/backends/xsmm/ceed-xsmm-tensor.c (revision 6c5df90db8677641a04ff505ccfa313a57dff4e5)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include "ceed-xsmm.h"
18 
19 // Utility functions for index in pointer array
20 int CeedGetXsmmInd_Tensor(CeedInt nelem, CeedInt add, CeedTransposeMode tmode,
21                           CeedInt B, CeedInt C, CeedInt J, CeedInt currdim,
22                           CeedInt dim) {
23   return (nelem == 8 ? 1:0)*4*2*dim + (add ? 1:0)*4*dim +
24          (tmode ? 1:0)*2*dim + (B == J ? 1:0)*dim + currdim;
25 }
26 
27 int CeedGetXsmmInd_NonTensor(CeedInt add, CeedInt P, CeedInt Q, CeedInt B,
28                              CeedInt C, CeedInt J) {
29   return (C == 8 ? 1:0)*4*2 + (add ? 1:0)*4 +
30          (B == P ? (J == Q ? 0:1) : (B == Q ? 2:3));
31 }
32 
33 // Default Tensor Contact
34 static int CeedTensorContract_Xsmm_C1(CeedTensorContract contract,
35     CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
36     CeedTransposeMode tmode, const CeedInt Add, const CeedScalar *restrict u,
37     CeedScalar *restrict v) {
38   CeedScalar alpha = 1.0, beta = 1.0;
39   char transu = 'N', transt = 'N';
40   if ((tmode == CEED_TRANSPOSE && C != 1)
41       || (tmode == CEED_NOTRANSPOSE && C == 1))
42     transt = 'T';
43 
44   if (!Add)
45     beta = 0.0;
46 
47   // libXSMM GEMM
48   libxsmm_dgemm(&transt, &transu, &J, &A, &B,
49                 &alpha, &t[0], NULL, &u[0], NULL,
50                 &beta, &v[0], NULL);
51 
52   return 0;
53 }
54 
55 // Switch for Tensor Contract
56 static int CeedTensorContractApply_Xsmm(CeedTensorContract contract, CeedInt A,
57                                         CeedInt B, CeedInt C, CeedInt J,
58                                         const CeedScalar *restrict t,
59                                         CeedTransposeMode tmode,
60                                         const CeedInt add,
61                                         const CeedScalar *restrict u,
62                                         CeedScalar *restrict v) {
63   int ierr;
64   CeedInt blksize = 8, ind, nelem;
65   CeedTensorContract_Xsmm *impl;
66   ierr = CeedTensorContractGetData(contract, (void *)&impl); CeedChk(ierr);
67 
68   // Get nelem and current dim
69   CeedScalar currdim = log(C/blksize) / log(J);
70   if (!(C % blksize) && currdim - (int)currdim < 1e-15) {
71     nelem = blksize;
72   } else {
73     nelem = 1;
74     currdim = log(C) / log(J);
75   }
76 
77   // Get kernel index
78   if (impl->tensorbasis)
79     ind = CeedGetXsmmInd_Tensor(nelem, add, tmode==CEED_TRANSPOSE?1:0, B, C, J,
80                                 (CeedInt)currdim, impl->dim);
81   else
82     ind = CeedGetXsmmInd_NonTensor(add, impl->P, impl->Q, B, C, J);
83 
84   // Run kernel or fallback to default implementation
85   if (C != 1)
86     for (CeedInt a=0; a<A; a++)
87       impl->kernels[ind](&u[a*B*C], &t[0], &v[a*J*C], NULL, NULL, NULL);
88   else
89     CeedTensorContract_Xsmm_C1(contract, A, B, C, J, t, tmode, add, u, v);
90 
91   return 0;
92 }
93 
94 static int CeedTensorContractDestroy_Xsmm(CeedTensorContract contract) {
95   int ierr;
96   CeedTensorContract_Xsmm *impl;
97   ierr = CeedTensorContractGetData(contract, (void *)&impl); CeedChk(ierr);
98   ierr = CeedFree(&impl->kernels); CeedChk(ierr);
99   ierr = CeedFree(&impl); CeedChk(ierr);
100 
101   return 0;
102 }
103 
104 int CeedTensorContractCreate_Xsmm(CeedBasis basis,
105                                   CeedTensorContract contract) {
106   int ierr;
107   Ceed ceed;
108   ierr = CeedTensorContractGetCeed(contract, &ceed); CeedChk(ierr);
109   CeedTensorContract_Xsmm *impl;
110   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
111 
112   // Set up pointers to kernels
113   ierr = CeedBasisGetTensorStatus(basis, &impl->tensorbasis); CeedChk(ierr);
114   if (impl->tensorbasis) {
115     ierr = CeedBasisGetNumNodes1D(basis, &impl->P); CeedChk(ierr);
116     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &impl->Q); CeedChk(ierr);
117     ierr = CeedBasisGetDimension(basis, &impl->dim); CeedChk(ierr);
118     // Set up kernel pointer array
119     impl->numkernels = 2*2*4*impl->dim;
120     ierr = CeedCalloc(impl->numkernels, &impl->kernels); CeedChk(ierr);
121     for (CeedInt nelem = 1; nelem <= 8; nelem+=7)
122       for (CeedInt add = 0; add <= 1; add++)
123         for (CeedInt tmode = 0; tmode <= 1; tmode++)
124           for (CeedInt grad = 0; grad <=1; grad++)
125             for (CeedInt dim = 0; dim < impl->dim; dim++) {
126               const int flags = LIBXSMM_GEMM_FLAGS('N', tmode ? 'T' : 'N');
127               CeedInt B = grad ? impl->Q : (tmode ? impl->Q : impl->P),
128                       J = grad ? impl->Q : (tmode ? impl->P : impl->Q),
129                       C = nelem*CeedIntPow(J, dim);
130               int ind = CeedGetXsmmInd_Tensor(nelem, add, tmode, B, C, J, dim,
131                                               impl->dim);
132               CeedScalar alpha = 1.0, beta = 1.0;
133               if (!add) beta = 0.0;
134               impl->kernels[ind] = libxsmm_dmmdispatch(C, J, B,
135                                    NULL, NULL, NULL, &alpha,
136                                    &beta, &flags, NULL);
137               if (!impl->kernels[ind])
138                 // LCOV_EXCL_START
139                 return CeedError(ceed, 1, "LIBXSMM kernel failed to build.");
140               // LCOV_EXCL_STOP
141             }
142   } else {
143     ierr = CeedBasisGetNumNodes(basis, &impl->P); CeedChk(ierr);
144     ierr = CeedBasisGetNumQuadraturePoints(basis, &impl->Q); CeedChk(ierr);
145     ierr = CeedBasisGetDimension(basis, &impl->dim); CeedChk(ierr);
146     // Set up kernel pointer array
147     impl->numkernels = 4*2*2;
148     ierr = CeedCalloc(impl->numkernels, &impl->kernels); CeedChk(ierr);
149     for (CeedInt nelem = 1; nelem <= 8; nelem+=7)
150       for (CeedInt add = 0; add <= 1; add++)
151         for (CeedInt tmode = 0; tmode <= 1; tmode++)
152           for (CeedInt grad = 1; grad <= impl->dim; grad+=impl->dim-1) {
153             const int flags = LIBXSMM_GEMM_FLAGS('N', tmode ? 'T' : 'N');
154             CeedInt B = tmode ? grad*impl->Q : impl->P,
155                     J = tmode ? impl->P : grad*impl->Q;
156             int ind = CeedGetXsmmInd_NonTensor(add, impl->P, impl->Q, B, nelem,
157                                                J);
158             CeedScalar alpha = 1.0, beta = 1.0;
159             if (!add) beta = 0.0;
160             impl->kernels[ind] = libxsmm_dmmdispatch(nelem, J, B,
161                                  NULL, NULL, NULL, &alpha,
162                                  &beta, &flags, NULL);
163             if (!impl->kernels[ind])
164               // LCOV_EXCL_START
165               return CeedError(ceed, 1, "LIBXSMM kernel failed to build.");
166             // LCOV_EXCL_STOP
167           }
168   }
169   ierr = CeedTensorContractSetData(contract, (void *)&impl); CeedChk(ierr);
170 
171   ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply",
172                                 CeedTensorContractApply_Xsmm); CeedChk(ierr);
173   ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Destroy",
174                                 CeedTensorContractDestroy_Xsmm); CeedChk(ierr);
175 
176   return 0;
177 }
178