xref: /libCEED/rust/libceed-sys/c-src/backends/xsmm/ceed-xsmm-tensor.c (revision d1d35e2f02dc969aee8debf3fd943dd784aa847a)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <ceed/ceed.h>
18 #include <ceed/backend.h>
19 #include <ceed/hash.h>
20 #include <ceed/khash.h>
21 #include <libxsmm.h>
22 #include <stddef.h>
23 #include "ceed-xsmm.h"
24 
25 //------------------------------------------------------------------------------
26 // Tensor Contract C=1
27 //------------------------------------------------------------------------------
28 static int CeedTensorContract_Xsmm_C1(CeedTensorContract contract,
29                                       CeedInt A, CeedInt B, CeedInt C,
30                                       CeedInt J, const CeedScalar *restrict t,
31                                       CeedTransposeMode t_mode,
32                                       const CeedInt add,
33                                       const CeedScalar *restrict u,
34                                       CeedScalar *restrict v) {
35   CeedScalar alpha = 1.0, beta = 1.0;
36   char trans_u = 'N', trans_t = 'N';
37   if ((t_mode == CEED_TRANSPOSE && C != 1)
38       || (t_mode == CEED_NOTRANSPOSE && C == 1))
39     trans_t = 'T';
40 
41   if (!add)
42     beta = 0.0;
43 
44   // libXSMM GEMM
45   libxsmm_dgemm(&trans_t, &trans_u, &J, &A, &B,
46                 &alpha, &t[0], NULL, &u[0], NULL,
47                 &beta, &v[0], NULL);
48 
49   return CEED_ERROR_SUCCESS;
50 }
51 
52 //------------------------------------------------------------------------------
53 // Tensor Contract Apply
54 //------------------------------------------------------------------------------
55 static int CeedTensorContractApply_Xsmm(CeedTensorContract contract, CeedInt A,
56                                         CeedInt B, CeedInt C, CeedInt J,
57                                         const CeedScalar *restrict t,
58                                         CeedTransposeMode t_mode,
59                                         const CeedInt add,
60                                         const CeedScalar *restrict u,
61                                         CeedScalar *restrict v) {
62   int ierr;
63   CeedTensorContract_Xsmm *impl;
64   ierr = CeedTensorContractGetData(contract, &impl); CeedChkBackend(ierr);
65 
66   // Get kernel
67   libxsmm_dmmfunction kernel;
68   CeedHashIJKLMKey key = {B, C, J, t_mode, add};
69   khint_t k = kh_get(m32, impl->lookup, key);
70   CeedHashGetValue(impl->lookup, k, kernel);
71 
72   // Run kernel or fallback to default implementation
73   if (C != 1)
74     for (CeedInt a=0; a<A; a++)
75       kernel(&u[a*B*C], &t[0], &v[a*J*C], NULL, NULL, NULL);
76   else
77     CeedTensorContract_Xsmm_C1(contract, A, B, C, J, t, t_mode, add, u, v);
78 
79   return CEED_ERROR_SUCCESS;
80 }
81 
82 //------------------------------------------------------------------------------
83 // Tensor Contract Destroy
84 //------------------------------------------------------------------------------
85 static int CeedTensorContractDestroy_Xsmm(CeedTensorContract contract) {
86   int ierr;
87   CeedTensorContract_Xsmm *impl;
88   libxsmm_dmmfunction kernel;
89 
90   ierr = CeedTensorContractGetData(contract, &impl); CeedChkBackend(ierr);
91   // Free kernels
92   kh_foreach_value(impl->lookup, kernel, libxsmm_release_kernel(&kernel));
93   kh_destroy(m32, impl->lookup);
94   ierr = CeedFree(&impl); CeedChkBackend(ierr);
95   return CEED_ERROR_SUCCESS;
96 }
97 
98 //------------------------------------------------------------------------------
99 // Tensor Contract Create
100 //------------------------------------------------------------------------------
101 int CeedTensorContractCreate_Xsmm(CeedBasis basis,
102                                   CeedTensorContract contract) {
103   int ierr;
104   Ceed ceed;
105   ierr = CeedTensorContractGetCeed(contract, &ceed); CeedChkBackend(ierr);
106   CeedTensorContract_Xsmm *impl;
107   ierr = CeedCalloc(1, &impl); CeedChkBackend(ierr);
108 
109   // Setup kernels hash table
110   impl->lookup = kh_init(m32);
111 
112   // Set up pointers to kernels
113   ierr = CeedBasisIsTensor(basis, &impl->is_tensor); CeedChkBackend(ierr);
114   if (impl->is_tensor) {
115     ierr = CeedBasisGetNumNodes1D(basis, &impl->P); CeedChkBackend(ierr);
116     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &impl->Q); CeedChkBackend(ierr);
117     ierr = CeedBasisGetDimension(basis, &impl->dim); CeedChkBackend(ierr);
118     // Build all required kernels
119     for (CeedInt num_elem = 1; num_elem <= 8; num_elem+=7)
120       for (CeedInt add = 0; add <= 1; add++)
121         for (CeedInt t_mode = 0; t_mode <= 1; t_mode++)
122           for (CeedInt grad = 0; grad <=1; grad++)
123             for (CeedInt dim = 0; dim < impl->dim; dim++) {
124               const int flags = LIBXSMM_GEMM_FLAGS('N', t_mode ? 'T' : 'N');
125               CeedInt B = grad ? impl->Q : (t_mode ? impl->Q : impl->P),
126                       J = grad ? impl->Q : (t_mode ? impl->P : impl->Q),
127                       C = num_elem*CeedIntPow(J, dim);
128               // Add key, kernel pair to hash table
129               CeedHashIJKLMKey key = {B, C, J, t_mode, add};
130               int new_item;
131               khint_t k = kh_put(m32, impl->lookup, key, &new_item);
132               if (new_item) {
133                 // Build kernel
134                 CeedScalar alpha = 1.0, beta = 1.0;
135                 if (!add) beta = 0.0;
136                 libxsmm_dmmfunction kernel = libxsmm_dmmdispatch(
137                                                C, J, B, NULL, NULL, NULL, &alpha, &beta, &flags, NULL);
138                 if (!kernel)
139                   // LCOV_EXCL_START
140                   return CeedError(ceed, CEED_ERROR_BACKEND, "LIBXSMM kernel failed to build.");
141                 // LCOV_EXCL_STOP
142                 // Add kernel to hash table
143                 kh_value(impl->lookup, k) = kernel;
144               }
145             }
146   } else {
147     ierr = CeedBasisGetNumNodes(basis, &impl->P); CeedChkBackend(ierr);
148     ierr = CeedBasisGetNumQuadraturePoints(basis, &impl->Q); CeedChkBackend(ierr);
149     ierr = CeedBasisGetDimension(basis, &impl->dim); CeedChkBackend(ierr);
150     // Build all required kernels
151     for (CeedInt num_elem = 1; num_elem <= 8; num_elem+=7)
152       for (CeedInt add = 0; add <= 1; add++)
153         for (CeedInt t_mode = 0; t_mode <= 1; t_mode++) {
154           CeedInt gradstride = CeedIntMax(impl->dim-1, 1);
155           for (CeedInt grad = 1; grad <= impl->dim; grad+=gradstride) {
156             const int flags = LIBXSMM_GEMM_FLAGS('N', t_mode ? 'T' : 'N');
157             CeedInt B = t_mode ? grad*impl->Q : impl->P,
158                     J = t_mode ? impl->P : grad*impl->Q,
159                     C = num_elem;
160             // Add key, kernel pair to hash table
161             CeedHashIJKLMKey key = {B, C, J, t_mode, add};
162             int new_item;
163             khint_t k = kh_put(m32, impl->lookup, key, &new_item);
164             if (new_item) {
165               // Build kernel
166               CeedScalar alpha = 1.0, beta = 1.0;
167               if (!add) beta = 0.0;
168               libxsmm_dmmfunction kernel = libxsmm_dmmdispatch(
169                                              C, J, B, NULL, NULL, NULL, &alpha, &beta, &flags, NULL);
170               if (!kernel)
171                 // LCOV_EXCL_START
172                 return CeedError(ceed, CEED_ERROR_BACKEND, "LIBXSMM kernel failed to build.");
173               // LCOV_EXCL_STOP
174               // Add kernel to hash table
175               kh_value(impl->lookup, k) = kernel;
176             }
177           }
178         }
179   }
180   ierr = CeedTensorContractSetData(contract, impl); CeedChkBackend(ierr);
181 
182   ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Apply",
183                                 CeedTensorContractApply_Xsmm); CeedChkBackend(ierr);
184   ierr = CeedSetBackendFunction(ceed, "TensorContract", contract, "Destroy",
185                                 CeedTensorContractDestroy_Xsmm); CeedChkBackend(ierr);
186 
187   return CEED_ERROR_SUCCESS;
188 }
189 //------------------------------------------------------------------------------
190