xref: /libCEED/rust/libceed-sys/c-src/backends/xsmm/ceed-xsmm-tensor.c (revision d99fa3c5cd91a1690aedf0679cbf290d44fec74c)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include "ceed-xsmm.h"
18 
19 //------------------------------------------------------------------------------
20 // Tensor Contract C=1
21 //------------------------------------------------------------------------------
22 static int CeedTensorContract_Xsmm_C1(CeedTensorContract contract,
23                                       CeedInt A, CeedInt B, CeedInt C,
24                                       CeedInt J, const CeedScalar *restrict t,
25                                       CeedTransposeMode tmode,
26                                       const CeedInt add,
27                                       const CeedScalar *restrict u,
28                                       CeedScalar *restrict v) {
29   CeedScalar alpha = 1.0, beta = 1.0;
30   char transu = 'N', transt = 'N';
31   if ((tmode == CEED_TRANSPOSE && C != 1)
32       || (tmode == CEED_NOTRANSPOSE && C == 1))
33     transt = 'T';
34 
35   if (!add)
36     beta = 0.0;
37 
38   // libXSMM GEMM
39   libxsmm_dgemm(&transt, &transu, &J, &A, &B,
40                 &alpha, &t[0], NULL, &u[0], NULL,
41                 &beta, &v[0], NULL);
42 
43   return 0;
44 }
45 
46 //------------------------------------------------------------------------------
47 // Tensor Contract Apply
48 //------------------------------------------------------------------------------
49 static int CeedTensorContractApply_Xsmm(CeedTensorContract contract, CeedInt A,
50                                         CeedInt B, CeedInt C, CeedInt J,
51                                         const CeedScalar *restrict t,
52                                         CeedTransposeMode tmode,
53                                         const CeedInt add,
54                                         const CeedScalar *restrict u,
55                                         CeedScalar *restrict v) {
56   int ierr;
57   CeedTensorContract_Xsmm *impl;
58   ierr = CeedTensorContractGetData(contract, (void *)&impl); CeedChk(ierr);
59 
60   // Get kernel
61   libxsmm_dmmfunction kernel;
62   CeedHashIJKLMKey key = {B, C, J, tmode, add};
63   khint_t k = kh_get(m32, impl->lookup, key);
64   CeedHashGetValue(impl->lookup, k, kernel);
65 
66   // Run kernel or fallback to default implementation
67   if (C != 1)
68     for (CeedInt a=0; a<A; a++)
69       kernel(&u[a*B*C], &t[0], &v[a*J*C], NULL, NULL, NULL);
70   else
71     CeedTensorContract_Xsmm_C1(contract, A, B, C, J, t, tmode, add, u, v);
72 
73   return 0;
74 }
75 
76 //------------------------------------------------------------------------------
77 // Tensor Contract Destroy
78 //------------------------------------------------------------------------------
79 static int CeedTensorContractDestroy_Xsmm(CeedTensorContract contract) {
80   int ierr;
81   CeedTensorContract_Xsmm *impl;
82   libxsmm_dmmfunction kernel;
83 
84   ierr = CeedTensorContractGetData(contract, (void *)&impl); CeedChk(ierr);
85   // Free kernels
86   kh_foreach_value(impl->lookup, kernel, libxsmm_release_kernel(&kernel));
87   kh_destroy(m32, impl->lookup);
88   ierr = CeedFree(&impl); CeedChk(ierr);
89   return 0;
90 }
91 
92 //------------------------------------------------------------------------------
93 // Tensor Contract Create
94 //------------------------------------------------------------------------------
95 int CeedTensorContractCreate_Xsmm(CeedBasis basis,
96                                   CeedTensorContract contract) {
97   int ierr;
98   Ceed ceed;
99   ierr = CeedTensorContractGetCeed(contract, &ceed); CeedChk(ierr);
100   CeedTensorContract_Xsmm *impl;
101   ierr = CeedCalloc(1, &impl); CeedChk(ierr);
102 
103   // Setup kernels hash table
104   impl->lookup = kh_init(m32);
105 
106   // Set up pointers to kernels
107   ierr = CeedBasisIsTensor(basis, &impl->isTensor); CeedChk(ierr);
108   if (impl->isTensor) {
109     ierr = CeedBasisGetNumNodes1D(basis, &impl->P); CeedChk(ierr);
110     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &impl->Q); CeedChk(ierr);
111     ierr = CeedBasisGetDimension(basis, &impl->dim); CeedChk(ierr);
112     // Build all required kernels
113     for (CeedInt nelem = 1; nelem <= 8; nelem+=7)
114       for (CeedInt add = 0; add <= 1; add++)
115         for (CeedInt tmode = 0; tmode <= 1; tmode++)
116           for (CeedInt grad = 0; grad <=1; grad++)
117             for (CeedInt dim = 0; dim < impl->dim; dim++) {
118               const int flags = LIBXSMM_GEMM_FLAGS('N', tmode ? 'T' : 'N');
119               CeedInt B = grad ? impl->Q : (tmode ? impl->Q : impl->P),
120                       J = grad ? impl->Q : (tmode ? impl->P : impl->Q),
121                       C = nelem*CeedIntPow(J, dim);
122               // Add key, kernel pair to hash table
123               CeedHashIJKLMKey key = {B, C, J, tmode, add};
124               int new_item;
125               khint_t k = kh_put(m32, impl->lookup, key, &new_item);
126               if (new_item) {
127                 // Build kernel
128                 CeedScalar alpha = 1.0, beta = 1.0;
129                 if (!add) beta = 0.0;
130                 libxsmm_dmmfunction kernel = libxsmm_dmmdispatch(
131                                                C, J, B, NULL, NULL, NULL, &alpha, &beta, &flags, NULL);
132                 if (!kernel)
133                   // LCOV_EXCL_START
134                   return CeedError(ceed, 1, "LIBXSMM kernel failed to build.");
135                 // LCOV_EXCL_STOP
136                 // Add kernel to hash table
137                 kh_value(impl->lookup, k) = kernel;
138               }
139             }
140   } else {
141     ierr = CeedBasisGetNumNodes(basis, &impl->P); CeedChk(ierr);
142     ierr = CeedBasisGetNumQuadraturePoints(basis, &impl->Q); CeedChk(ierr);
143     ierr = CeedBasisGetDimension(basis, &impl->dim); CeedChk(ierr);
144     // Build all required kernels
145     for (CeedInt nelem = 1; nelem <= 8; nelem+=7)
146       for (CeedInt add = 0; add <= 1; add++)
147         for (CeedInt tmode = 0; tmode <= 1; tmode++) {
148           CeedInt gradstride = CeedIntMax(impl->dim-1, 1);
149           for (CeedInt grad = 1; grad <= impl->dim; grad+=gradstride) {
150             const int flags = LIBXSMM_GEMM_FLAGS('N', tmode ? 'T' : 'N');
151             CeedInt B = tmode ? grad*impl->Q : impl->P,
152                     J = tmode ? impl->P : grad*impl->Q,
153                     C = nelem;
154             // Add key, kernel pair to hash table
155             CeedHashIJKLMKey key = {B, C, J, tmode, add};
156             int new_item;
157             khint_t k = kh_put(m32, impl->lookup, key, &new_item);
158             if (new_item) {
159               // Build kernel
160               CeedScalar alpha = 1.0, beta = 1.0;
161               if (!add) beta = 0.0;
162               libxsmm_dmmfunction kernel = libxsmm_dmmdispatch(
163                                              C, J, B, NULL, NULL, NULL, &alpha, &beta, &flags, NULL);
164               if (!kernel)
165                 // LCOV_EXCL_START
166                 return CeedError(ceed, 1, "LIBXSMM kernel failed to build.");
167               // LCOV_EXCL_STOP
168               // Add kernel to hash table
169               kh_value(impl->lookup, k) = kernel;
170             }
171           }
172         }
173   }
174   ierr = CeedTensorContractSetData(contract, (void *)&impl); CeedChk(ierr);
175 
176   ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply",
177                                 CeedTensorContractApply_Xsmm); CeedChk(ierr);
178   ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Destroy",
179                                 CeedTensorContractDestroy_Xsmm); CeedChk(ierr);
180 
181   return 0;
182 }
183 //------------------------------------------------------------------------------
184