xref: /libCEED/rust/libceed-sys/c-src/backends/cuda-shared/ceed-cuda-shared-basis.c (revision c532df63125d558c3b9bc506378826def3210255)
1*c532df63SYohann // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2*c532df63SYohann // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3*c532df63SYohann // All Rights reserved. See files LICENSE and NOTICE for details.
4*c532df63SYohann //
5*c532df63SYohann // This file is part of CEED, a collection of benchmarks, miniapps, software
6*c532df63SYohann // libraries and APIs for efficient high-order finite element and spectral
7*c532df63SYohann // element discretizations for exascale applications. For more information and
8*c532df63SYohann // source code availability see http://github.com/ceed.
9*c532df63SYohann //
10*c532df63SYohann // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11*c532df63SYohann // a collaborative effort of two U.S. Department of Energy organizations (Office
12*c532df63SYohann // of Science and the National Nuclear Security Administration) responsible for
13*c532df63SYohann // the planning and preparation of a capable exascale ecosystem, including
14*c532df63SYohann // software, applications, hardware, advanced system engineering and early
15*c532df63SYohann // testbed platforms, in support of the nation's exascale computing imperative.
16*c532df63SYohann 
17*c532df63SYohann #include <ceed-backend.h>
18*c532df63SYohann #include <ceed.h>
19*c532df63SYohann #include "ceed-cuda-shared.h"
20*c532df63SYohann #include "../cuda/ceed-cuda.h"
21*c532df63SYohann 
22*c532df63SYohann //*********************
23*c532df63SYohann // shared mem kernels
24*c532df63SYohann static const char *kernelsShared = QUOTE(
25*c532df63SYohann 
// Accumulate r_U into r_V entry by entry: r_V[q] += r_U[q] for every q < Q1D.
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  int q = 0;
  while (q < Q1D) {
    r_V[q] += r_U[q];
    ++q;
  }
}
30*c532df63SYohann 
31*c532df63SYohann //////////
32*c532df63SYohann //  1D  //
33*c532df63SYohann //////////
34*c532df63SYohann 
// Copy the P1D nodal values of component `comp` of element `elem` into
// `slice`, then zero-fill slice[P1D..Q1D).
// Input layout: d_U[i + P1D*(comp + BASIS_NCOMP*elem)].
// Every calling thread performs the same full copy; tidx, tidy and nelem are unused.
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U, CeedScalar *slice) {
  const int base = (elem*BASIS_NCOMP + comp)*P1D;
  int n = 0;
  while (n < P1D) {
    slice[n] = d_U[base + n];
    ++n;
  }
  // Zero padding so the contraction may read up to Q1D entries.
  while (n < Q1D) {
    slice[n] = 0.0;
    ++n;
  }
}
43*c532df63SYohann 
// Store this thread's nodal result r_V at d_V[tidx + P1D*(comp + BASIS_NCOMP*elem)].
// Threads with tidx >= P1D own no node and write nothing; tidy and nelem are unused.
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx >= P1D)
    return;
  const int base = (elem*BASIS_NCOMP + comp)*P1D;
  d_V[base + tidx] = r_V;
}
51*c532df63SYohann 
// Copy the Q1D quadrature values of (dim, comp, elem) into `slice`.
// Input layout: d_U[q + Q1D*(elem + nelem*(comp + BASIS_NCOMP*dim))].
// Every calling thread performs the same full copy; tidx and tidy are unused.
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem, const CeedScalar *d_U, CeedScalar *slice) {
  const int base = Q1D*(elem + nelem*(comp + dim*BASIS_NCOMP));
  for (int q = 0; q < Q1D; ++q)
    slice[q] = d_U[base + q];
}
58*c532df63SYohann 
// Store this thread's quadrature result r_V at quadrature point tidx of
// (dim, comp, elem): d_V[tidx + Q1D*(elem + nelem*(comp + BASIS_NCOMP*dim))].
// No bounds guard on tidx; this presumably relies on exactly Q1D threads
// along x (TODO confirm in the launch code). tidy is unused.
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem, const CeedScalar &r_V, CeedScalar *d_V) {
  const int base = Q1D*(elem + nelem*(comp + dim*BASIS_NCOMP));
  d_V[base + tidx] = r_V;
}
64*c532df63SYohann 
// 1D contraction: V = sum_{i < P1D} B[i + tidx*P1D] * slice[i], i.e. row tidx
// of the Q1D x P1D matrix B applied to the nodal values staged in shared `slice`.
// `slice` is filled by every thread of the block just before this call and is
// refilled for the next component right after it. The leading barrier makes
// those writes visible before they are read. The trailing barrier stops a fast
// thread from overwriting `slice` while slower threads are still reading it.
// Like the 2D/3D contractions, this must be reached by all threads of the block.
// U and tidy are unused.
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy,
                                   const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  __syncthreads();
  V = 0.0;
  for (int i = 0; i < P1D; ++i) {
    V += B[i + tidx*P1D] * slice[i]; // contract x direction
  }
  __syncthreads();
}
73*c532df63SYohann 
// Transposed 1D contraction: V = sum_{q < Q1D} B[tidx + q*P1D] * slice[q],
// i.e. column tidx of the Q1D x P1D matrix B applied to the quadrature values
// staged in shared `slice`.
// Threads with tidx >= P1D own no node. They skip the sum and get V = 0
// (their result is discarded by the dof write anyway). Without the guard,
// B[tidx + q*P1D] would index past the P1D*Q1D entries of B whenever Q1D > P1D.
// The leading barrier makes the `slice` writes of all threads visible before
// they are read. The trailing barrier stops a fast thread from refilling
// `slice` for the next component while others still read it.
// Must be reached by all threads of the block. U and tidy are unused.
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
    const int tidy,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  __syncthreads();
  V = 0.0;
  if (tidx < P1D) {
    for (int i = 0; i < Q1D; ++i) {
      V += B[tidx + i*P1D] * slice[i]; // contract x direction
    }
  }
  __syncthreads();
}
82*c532df63SYohann 
// 1D interpolation with matrix c_B. For each element handled by this thread
// (element index starts at blockIdx.x*blockDim.z + threadIdx.z and advances by
// gridDim.x*blockDim.z) and each component:
//   forward:   nodal values  -> shared slice -> B        -> one quadrature value per thread
//   transpose: quad values   -> shared slice -> B^T      -> one nodal value per thread
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  const int tx = threadIdx.x;
  const int ty = threadIdx.y;
  const CeedInt stride = gridDim.x*blockDim.z;
  CeedScalar out;
  CeedScalar unused; // the 1D contractions ignore their U argument

  for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < nelem; e += stride) {
    for (int c = 0; c < BASIS_NCOMP; ++c) {
      if (transpose) {
        readQuads1d(e, tx, ty, c, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tx, ty, unused, c_B, out);
        writeDofs1d(e, tx, ty, c, nelem, out, d_V);
      } else {
        readDofs1d(e, tx, ty, c, nelem, d_U, slice);
        ContractX1d(slice, tx, ty, unused, c_B, out);
        writeQuads1d(e, tx, ty, c, 0, nelem, out, d_V);
      }
    }
  }
}
109*c532df63SYohann 
// 1D gradient: same data flow as interp1d but contracting with the derivative
// matrix c_G (c_B is not used in 1D). Only the dim = 0 quadrature block exists.
// tidy is only forwarded to the 1D helpers, which do not use it.
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  const int tx = threadIdx.x;
  const int ty = threadIdx.y;
  const int d = 0;
  const CeedInt stride = gridDim.x*blockDim.z;
  CeedScalar out;
  CeedScalar unused; // the 1D contractions ignore their U argument

  for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < nelem; e += stride) {
    for (int c = 0; c < BASIS_NCOMP; ++c) {
      if (transpose) {
        readQuads1d(e, tx, ty, c, d, nelem, d_U, slice);
        ContractTransposeX1d(slice, tx, ty, unused, c_G, out);
        writeDofs1d(e, tx, ty, c, nelem, out, d_V);
      } else {
        readDofs1d(e, tx, ty, c, nelem, d_U, slice);
        ContractX1d(slice, tx, ty, unused, c_G, out);
        writeQuads1d(e, tx, ty, c, d, nelem, out, d_V);
      }
    }
  }
}
138*c532df63SYohann //////////
139*c532df63SYohann //  2D  //
140*c532df63SYohann //////////
141*c532df63SYohann 
// Load the nodal value at (tidx, tidy) of component `comp` of element `elem`.
// Input layout: d_U[tidx + tidy*P1D + P1D*P1D*(comp + BASIS_NCOMP*elem)].
// Threads outside the P1D x P1D node grid get 0 (zero padding). nelem is unused.
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U, CeedScalar &U) {
  const bool owns_node = (tidx < P1D) && (tidy < P1D);
  if (owns_node)
    U = d_U[(elem*BASIS_NCOMP + comp)*P1D*P1D + tidy*P1D + tidx];
  else
    U = 0.0;
}
149*c532df63SYohann 
// Store the nodal result r_V of thread (tidx, tidy) using the same layout as
// the 2D dof read: d_V[tidx + tidy*P1D + P1D*P1D*(comp + BASIS_NCOMP*elem)].
// Threads outside the P1D x P1D node grid write nothing. nelem is unused.
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D)
    return;
  d_V[(elem*BASIS_NCOMP + comp)*P1D*P1D + tidy*P1D + tidx] = r_V;
}
157*c532df63SYohann 
// Load the quadrature value at point (tidx, tidy) of (dim, comp, elem).
// Layout: d_U[tidx + tidy*Q1D + Q1D*Q1D*(elem + nelem*(comp + BASIS_NCOMP*dim))].
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem, const CeedScalar *d_U, CeedScalar &U ) {
  const int tile = Q1D*Q1D*(elem + nelem*(comp + dim*BASIS_NCOMP));
  U = d_U[tile + tidy*Q1D + tidx];
}
164*c532df63SYohann 
// Store the quadrature result r_V at point (tidx, tidy) of (dim, comp, elem).
// Layout: d_V[tidx + tidy*Q1D + Q1D*Q1D*(elem + nelem*(comp + BASIS_NCOMP*dim))].
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem, const CeedScalar &r_V, CeedScalar *d_V) {
  const int tile = Q1D*Q1D*(elem + nelem*(comp + dim*BASIS_NCOMP));
  d_V[tile + tidy*Q1D + tidx] = r_V;
}
171*c532df63SYohann 
// Contract along x: stage U in the shared Q1D x Q1D tile, then
// V = sum_{i < P1D} B[tidx*P1D + i] * tile[tidy][i].
// Barriers bracket the read so the tile is complete before use and is not
// overwritten by a later call while still being read.
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy,
                                   const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  CeedScalar *row = slice + tidy*Q1D;
  row[tidx] = U;
  __syncthreads();
  const CeedScalar *Brow = B + tidx*P1D;
  CeedScalar acc = 0.0;
  for (int i = 0; i < P1D; ++i)
    acc += Brow[i] * row[i];
  V = acc;
  __syncthreads();
}
183*c532df63SYohann 
// Contract along y: stage U in the shared Q1D x Q1D tile, then
// V = sum_{j < P1D} B[tidy*P1D + j] * tile[j][tidx].
// Barriers bracket the read for the same reasons as the x contraction.
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy,
                                   const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidy*Q1D + tidx] = U;
  __syncthreads();
  const CeedScalar *Brow = B + tidy*P1D;
  CeedScalar acc = 0.0;
  for (int j = 0; j < P1D; ++j)
    acc += Brow[j] * slice[j*Q1D + tidx];
  V = acc;
  __syncthreads();
}
195*c532df63SYohann 
// Transposed contraction along y: stage U in the shared tile, then for rows
// tidy < P1D compute V = sum_{q < Q1D} B[q*P1D + tidy] * tile[q][tidx].
// Threads with tidy >= P1D get V = 0. Every thread still hits both barriers.
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
    const int tidy,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidy*Q1D + tidx] = U;
  __syncthreads();
  CeedScalar acc = 0.0;
  if (tidy < P1D) {
    for (int q = 0; q < Q1D; ++q)
      acc += B[q*P1D + tidy] * slice[q*Q1D + tidx];
  }
  V = acc;
  __syncthreads();
}
209*c532df63SYohann 
// Transposed contraction along x: stage U in the shared tile, then for
// columns tidx < P1D compute V = sum_{q < Q1D} B[q*P1D + tidx] * tile[tidy][q].
// Threads with tidx >= P1D get V = 0. Every thread still hits both barriers.
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
    const int tidy,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  CeedScalar *row = slice + tidy*Q1D;
  row[tidx] = U;
  __syncthreads();
  CeedScalar acc = 0.0;
  if (tidx < P1D) {
    for (int q = 0; q < Q1D; ++q)
      acc += B[q*P1D + tidx] * row[q];
  }
  V = acc;
  __syncthreads();
}
223*c532df63SYohann 
// 2D interpolation with c_B, one value per (tidx, tidy) thread.
//   forward:   zero-padded nodal value -> contract x -> contract y -> quadrature value
//   transpose: quadrature value -> transposed y -> transposed x -> nodal value
// The element index starts at blockIdx.x*blockDim.z + threadIdx.z and advances
// by gridDim.x*blockDim.z.
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  const int tx = threadIdx.x;
  const int ty = threadIdx.y;
  const CeedInt stride = gridDim.x*blockDim.z;

  for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < nelem; e += stride) {
    for (int c = 0; c < BASIS_NCOMP; ++c) {
      CeedScalar val = 0.0;
      CeedScalar tmp = 0.0;
      if (transpose) {
        readQuads2d(e, tx, ty, c, 0, nelem, d_U, val);
        ContractTransposeY2d(slice, tx, ty, val, c_B, tmp);
        ContractTransposeX2d(slice, tx, ty, tmp, c_B, val);
        writeDofs2d(e, tx, ty, c, nelem, val, d_V);
      } else {
        readDofs2d(e, tx, ty, c, nelem, d_U, val);
        ContractX2d(slice, tx, ty, val, c_B, tmp);
        ContractY2d(slice, tx, ty, tmp, c_B, val);
        writeQuads2d(e, tx, ty, c, 0, nelem, val, d_V);
      }
    }
  }
}
253*c532df63SYohann 
// 2D gradient. For derivative direction d, the operator along x is c_G when
// d == 0 (else c_B), and the operator along y is c_G when d == 1 (else c_B).
//   forward:   write one quadrature block per direction d = 0, 1
//   transpose: sum the transposed contributions of both direction blocks
//              into one nodal value
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  const int tx = threadIdx.x;
  const int ty = threadIdx.y;
  const CeedInt stride = gridDim.x*blockDim.z;
  CeedScalar in;
  CeedScalar tmp;
  CeedScalar out;

  for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < nelem; e += stride) {
    for (int c = 0; c < BASIS_NCOMP; ++c) {
      if (transpose) {
        CeedScalar sum = 0.0;
        for (int d = 0; d < 2; ++d) {
          const CeedScalar *opX = (d == 0) ? c_G : c_B;
          const CeedScalar *opY = (d == 1) ? c_G : c_B;
          readQuads2d(e, tx, ty, c, d, nelem, d_U, in);
          ContractTransposeY2d(slice, tx, ty, in, opY, tmp);
          ContractTransposeX2d(slice, tx, ty, tmp, opX, out);
          // The first direction initializes the sum, later ones accumulate.
          sum = (d == 0) ? out : sum + out;
        }
        writeDofs2d(e, tx, ty, c, nelem, sum, d_V);
      } else {
        readDofs2d(e, tx, ty, c, nelem, d_U, in);
        for (int d = 0; d < 2; ++d) {
          const CeedScalar *opX = (d == 0) ? c_G : c_B;
          const CeedScalar *opY = (d == 1) ? c_G : c_B;
          ContractX2d(slice, tx, ty, in, opX, tmp);
          ContractY2d(slice, tx, ty, tmp, opY, out);
          writeQuads2d(e, tx, ty, c, d, nelem, out, d_V);
        }
      }
    }
  }
}
294*c532df63SYohann //////////
295*c532df63SYohann //  3D  //
296*c532df63SYohann //////////
297*c532df63SYohann 
// Load the column of P1D nodal values at (tidx, tidy) of component `comp` of
// element `elem` into r_U[0..P1D), zero-filling r_U[P1D..Q1D).
// Layout: d_U[tidx + tidy*P1D + z*P1D*P1D + P1D^3*(comp + BASIS_NCOMP*elem)].
// Threads outside the P1D x P1D node grid load zeros. nelem is unused.
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U, CeedScalar *r_U) {
  const bool owns_column = (tidx < P1D) && (tidy < P1D);
  const int base = (elem*BASIS_NCOMP + comp)*P1D*P1D*P1D + tidy*P1D + tidx;
  for (int z = 0; z < P1D; ++z)
    r_U[z] = owns_column ? d_U[base + z*P1D*P1D] : 0.0;
  for (int z = P1D; z < Q1D; ++z)
    r_U[z] = 0.0;
}
308*c532df63SYohann 
// Load the column of Q1D quadrature values at (tidx, tidy) of (dim, comp, elem).
// Layout: d_U[tidx + tidy*Q1D + z*Q1D*Q1D + Q1D^3*(elem + nelem*(comp + BASIS_NCOMP*dim))].
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem, const CeedScalar *d_U, CeedScalar *r_U) {
  const int base = tidx + tidy*Q1D + Q1D*Q1D*Q1D*(elem + nelem*(comp + dim*BASIS_NCOMP));
  for (int z = 0; z < Q1D; ++z)
    r_U[z] = d_U[base + z*Q1D*Q1D];
}
316*c532df63SYohann 
// Store the P1D nodal results r_V[0..P1D) of column (tidx, tidy) using the
// layout d_V[tidx + tidy*P1D + z*P1D*P1D + P1D^3*(comp + BASIS_NCOMP*elem)].
// Threads outside the P1D x P1D node grid write nothing. nelem is unused.
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V, CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D)
    return;
  const int base = (elem*BASIS_NCOMP + comp)*P1D*P1D*P1D + tidy*P1D + tidx;
  for (int z = 0; z < P1D; ++z)
    d_V[base + z*P1D*P1D] = r_V[z];
}
326*c532df63SYohann 
// Store the Q1D quadrature results r_V of column (tidx, tidy) of (dim, comp, elem).
// Layout: d_V[tidx + tidy*Q1D + z*Q1D*Q1D + Q1D^3*(elem + nelem*(comp + BASIS_NCOMP*dim))].
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem, const CeedScalar *r_V, CeedScalar *d_V) {
  const int base = tidx + tidy*Q1D + Q1D*Q1D*Q1D*(elem + nelem*(comp + dim*BASIS_NCOMP));
  for (int z = 0; z < Q1D; ++z)
    d_V[base + z*Q1D*Q1D] = r_V[z];
}
334*c532df63SYohann 
// Contract along x for each of the first P1D z-levels: stage U[z] in the
// shared tile, then V[z] = sum_{i < P1D} B[tidx*P1D + i] * tile[tidy][i].
// Each level is bracketed by barriers because the tile is reused per level.
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy,
                                   const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  const CeedScalar *Brow = B + tidx*P1D;
  CeedScalar *row = slice + tidy*Q1D;
  for (int z = 0; z < P1D; ++z) {
    row[tidx] = U[z];
    __syncthreads();
    CeedScalar acc = 0.0;
    for (int i = 0; i < P1D; ++i)
      acc += Brow[i] * row[i];
    V[z] = acc;
    __syncthreads();
  }
}
348*c532df63SYohann 
// Contract along y for each of the first P1D z-levels: stage U[z] in the
// shared tile, then V[z] = sum_{j < P1D} B[tidy*P1D + j] * tile[j][tidx].
// Each level is bracketed by barriers because the tile is reused per level.
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy,
                                   const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  const CeedScalar *Brow = B + tidy*P1D;
  for (int z = 0; z < P1D; ++z) {
    slice[tidy*Q1D + tidx] = U[z];
    __syncthreads();
    CeedScalar acc = 0.0;
    for (int j = 0; j < P1D; ++j)
      acc += Brow[j] * slice[j*Q1D + tidx];
    V[z] = acc;
    __syncthreads();
  }
}
362*c532df63SYohann 
// Contract along z entirely in this thread registers (no shared memory):
// V[q] = sum_{i < P1D} B[q*P1D + i] * U[i] for q < Q1D. slice, tidx and tidy are unused.
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy,
                                   const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int q = 0; q < Q1D; ++q) {
    const CeedScalar *Brow = B + q*P1D;
    CeedScalar acc = 0.0;
    for (int i = 0; i < P1D; ++i)
      acc += Brow[i] * U[i];
    V[q] = acc;
  }
}
373*c532df63SYohann 
// Transposed contraction along z in registers:
// V[k] = sum_{q < Q1D} B[q*P1D + k] * U[q] for k < P1D, and V[k] = 0 for
// P1D <= k < Q1D. slice, tidx and tidy are unused.
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
    const int tidy,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < Q1D; ++k) {
    CeedScalar acc = 0.0;
    if (k < P1D) {
      for (int q = 0; q < Q1D; ++q)
        acc += B[k + q*P1D] * U[q];
    }
    V[k] = acc;
  }
}
386*c532df63SYohann 
// Transposed contraction along y for each of the first P1D z-levels: stage
// U[z] in the shared tile, then for tidy < P1D
// V[z] = sum_{q < Q1D} B[q*P1D + tidy] * tile[q][tidx]. Other threads get V[z] = 0.
// Every thread passes both barriers of every level.
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
    const int tidy,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int z = 0; z < P1D; ++z) {
    slice[tidy*Q1D + tidx] = U[z];
    __syncthreads();
    CeedScalar acc = 0.0;
    if (tidy < P1D) {
      for (int q = 0; q < Q1D; ++q)
        acc += B[tidy + q*P1D] * slice[q*Q1D + tidx];
    }
    V[z] = acc;
    __syncthreads();
  }
}
402*c532df63SYohann 
// Transposed contraction along x for each of the first P1D z-levels: stage
// U[z] in the shared tile, then for tidx < P1D
// V[z] = sum_{q < Q1D} B[q*P1D + tidx] * tile[tidy][q]. Other threads get V[z] = 0.
// Every thread passes both barriers of every level.
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
    const int tidy,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  CeedScalar *row = slice + tidy*Q1D;
  for (int z = 0; z < P1D; ++z) {
    row[tidx] = U[z];
    __syncthreads();
    CeedScalar acc = 0.0;
    if (tidx < P1D) {
      for (int q = 0; q < Q1D; ++q)
        acc += B[tidx + q*P1D] * row[q];
    }
    V[z] = acc;
    __syncthreads();
  }
}
418*c532df63SYohann 
// 3D interpolation with c_B. Each (tidx, tidy) thread carries a column of
// Q1D values in registers, ping-ponging between two buffers:
//   forward:   dofs -> x -> y -> z -> quadrature values
//   transpose: quadrature values -> z^T -> y^T -> x^T -> dofs
// Both buffers are zeroed for every component before use.
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar bufA[Q1D];
  CeedScalar bufB[Q1D];
  const int tx = threadIdx.x;
  const int ty = threadIdx.y;
  const CeedInt stride = gridDim.x*blockDim.z;

  for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < nelem; e += stride) {
    for (int c = 0; c < BASIS_NCOMP; ++c) {
      for (int q = 0; q < Q1D; ++q) {
        bufA[q] = 0.0;
        bufB[q] = 0.0;
      }
      if (transpose) {
        readQuads3d(e, tx, ty, c, 0, nelem, d_U, bufA);
        ContractTransposeZ3d(slice, tx, ty, bufA, c_B, bufB);
        ContractTransposeY3d(slice, tx, ty, bufB, c_B, bufA);
        ContractTransposeX3d(slice, tx, ty, bufA, c_B, bufB);
        writeDofs3d(e, tx, ty, c, nelem, bufB, d_V);
      } else {
        readDofs3d(e, tx, ty, c, nelem, d_U, bufA);
        ContractX3d(slice, tx, ty, bufA, c_B, bufB);
        ContractY3d(slice, tx, ty, bufB, c_B, bufA);
        ContractZ3d(slice, tx, ty, bufA, c_B, bufB);
        writeQuads3d(e, tx, ty, c, 0, nelem, bufB, d_V);
      }
    }
  }
}
452*c532df63SYohann 
// 3D gradient. For derivative direction dim, c_G is applied along that axis
// and c_B along the other two.
//   forward:   write one quadrature block per direction dim = 0, 1, 2
//   transpose: sum the transposed contributions of the three direction blocks
//              into r_V and write the nodal values once
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  CeedScalar r_U[Q1D];
  CeedScalar r_V[Q1D];
  CeedScalar r_t[Q1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  int dim;

  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for (int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) {
        readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
        ContractX3d(slice, tidx, tidy, r_U, c_G, r_V);
        ContractY3d(slice, tidx, tidy, r_V, c_B, r_t);
        ContractZ3d(slice, tidx, tidy, r_t, c_B, r_V);
        dim = 0;
        writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
        ContractX3d(slice, tidx, tidy, r_U, c_B, r_V);
        ContractY3d(slice, tidx, tidy, r_V, c_G, r_t);
        ContractZ3d(slice, tidx, tidy, r_t, c_B, r_V);
        dim = 1;
        writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
        ContractX3d(slice, tidx, tidy, r_U, c_B, r_V);
        ContractY3d(slice, tidx, tidy, r_V, c_B, r_t);
        ContractZ3d(slice, tidx, tidy, r_t, c_G, r_V);
        dim = 2;
        writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      } else {
        // ContractTransposeX3d only writes r_V[0..P1D), but add() reads all
        // Q1D entries. Zero the accumulator so the tail is never read
        // uninitialized, matching the zero-initialization in interp3d.
        for (int i = 0; i < Q1D; ++i) {
          r_V[i] = 0.0;
        }
        dim = 0;
        readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
        ContractTransposeZ3d(slice, tidx, tidy, r_U, c_B, r_t);
        ContractTransposeY3d(slice, tidx, tidy, r_t, c_B, r_U);
        ContractTransposeX3d(slice, tidx, tidy, r_U, c_G, r_V);
        dim = 1;
        readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
        ContractTransposeZ3d(slice, tidx, tidy, r_U, c_B, r_t);
        ContractTransposeY3d(slice, tidx, tidy, r_t, c_G, r_U);
        ContractTransposeX3d(slice, tidx, tidy, r_U, c_B, r_t);
        add(r_V, r_t);
        dim = 2;
        readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
        ContractTransposeZ3d(slice, tidx, tidy, r_U, c_G, r_t);
        ContractTransposeY3d(slice, tidx, tidy, r_t, c_B, r_U);
        ContractTransposeX3d(slice, tidx, tidy, r_U, c_B, r_t);
        add(r_V, r_t);
        writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
509*c532df63SYohann 
510*c532df63SYohann /////////////
511*c532df63SYohann // Kernels //
512*c532df63SYohann /////////////
// Interpolation kernel: dispatches to the 1D/2D/3D implementation selected by
// the compile-time BASIS_DIM. All threads of a block share one Q1D*Q1D
// staging buffer that is not indexed by threadIdx.z, so this assumes one
// element per block (blockDim.z == 1).
extern "C" __global__ void interp(const CeedInt nelem, const int transpose,
                                  const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  // CeedScalar (not a hard-coded double) so the buffer matches the
  // CeedScalar* slice parameter of the per-dimension helpers.
  __shared__ CeedScalar slice[Q1D*Q1D]; // FIXME: needs one tile per element if ElemPerBlock > 1
  if (BASIS_DIM == 1) {
    interp1d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM == 2) {
    interp2d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM == 3) {
    interp3d(nelem, transpose, c_B, d_U, d_V, slice);
  }
}
525*c532df63SYohann 
// Gradient kernel: dispatches to the 1D/2D/3D implementation selected by the
// compile-time BASIS_DIM. All threads of a block share one Q1D*Q1D staging
// buffer that is not indexed by threadIdx.z, so this assumes one element per
// block (blockDim.z == 1).
extern "C" __global__ void grad(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *c_G,
                                const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V) {
  // CeedScalar (not a hard-coded double) so the buffer matches the
  // CeedScalar* slice parameter of the per-dimension helpers.
  __shared__ CeedScalar slice[Q1D*Q1D]; // FIXME: needs one tile per element if ElemPerBlock > 1
  if (BASIS_DIM == 1) {
    grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM == 2) {
    grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM == 3) {
    grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  }
}
538*c532df63SYohann 
539*c532df63SYohann /////////////
540*c532df63SYohann // Weights //
541*c532df63SYohann /////////////
// Fill per-element 1D quadrature weights: w[e*Q1D + i] = qweight1d[i].
// One thread writes a whole element. Threads stride over elements by
// blockDim.x * gridDim.x (grid-stride over x).
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  CeedScalar wq[Q1D];
  for (int q = 0; q < Q1D; ++q)
    wq[q] = qweight1d[q];

  const int step = blockDim.x * gridDim.x;
  for (int e = blockIdx.x * blockDim.x + threadIdx.x; e < nelem; e += step) {
    CeedScalar *we = w + e*Q1D;
    for (int q = 0; q < Q1D; ++q)
      we[q] = wq[q];
  }
}
557*c532df63SYohann 
// Fill per-element 2D tensor-product weights:
// w[e*Q1D*Q1D + i + j*Q1D] = qweight1d[i] * qweight1d[j].
// One thread writes a whole element. Threads stride over elements by
// blockDim.x * gridDim.x (grid-stride over x).
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  CeedScalar wq[Q1D];
  for (int q = 0; q < Q1D; ++q)
    wq[q] = qweight1d[q];

  const int step = blockDim.x * gridDim.x;
  for (int e = blockIdx.x * blockDim.x + threadIdx.x; e < nelem; e += step) {
    CeedScalar *we = w + e*Q1D*Q1D;
    for (int j = 0; j < Q1D; ++j)
      for (int i = 0; i < Q1D; ++i)
        we[i + j*Q1D] = wq[i]*wq[j];
  }
}
575*c532df63SYohann 
// 3D quadrature weights: point (i, j, k) of each element receives
// qweight1d[i] * qweight1d[j] * qweight1d[k], stored at offset
// i + j*Q1D + k*Q1D*Q1D inside that element's block of Q1D^3 values
// (i varies fastest). Each thread handles one element at a time in a
// grid-stride loop over the nelem elements.
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  // Keep the 1D weights in a thread-local array.
  CeedScalar q_w[Q1D];
  for (int q = 0; q < Q1D; q++)
    q_w[q] = qweight1d[q];

  const unsigned int stride = blockDim.x * gridDim.x;
  for (int elem = blockIdx.x * blockDim.x + threadIdx.x; elem < nelem; elem += stride) {
    CeedScalar *w_elem = w + elem * Q1D * Q1D * Q1D;
    for (int k = 0; k < Q1D; k++)
      for (int j = 0; j < Q1D; j++)
        for (int i = 0; i < Q1D; i++)
          // Product kept in the order i, j, k so rounding is unchanged.
          w_elem[i + j * Q1D + k * Q1D * Q1D] = q_w[i] * q_w[j] * q_w[k];
  }
}
595*c532df63SYohann 
// Quadrature-weight kernel entry point: selects the weight1d, weight2d or
// weight3d helper matching the compile-time BASIS_DIM and writes the tensor
// product weights for all nelem elements into v. Any other BASIS_DIM value
// writes nothing.
extern "C" __global__ void weight(const CeedInt nelem,
                                  const CeedScalar *__restrict__ qweight1d, CeedScalar *__restrict__ v) {
  switch (BASIS_DIM) {
  case 1:
    weight1d(nelem, qweight1d, v);
    break;
  case 2:
    weight2d(nelem, qweight1d, v);
    break;
  case 3:
    weight3d(nelem, qweight1d, v);
    break;
  default:
    break;
  }
}
606*c532df63SYohann 
607*c532df63SYohann                                    );
608*c532df63SYohann 
609*c532df63SYohann int CeedCudaInitInterp(CeedScalar *d_B, CeedInt P1d, CeedInt Q1d,
610*c532df63SYohann                        CeedScalar **c_B);
611*c532df63SYohann int CeedCudaInitInterpGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P1d,
612*c532df63SYohann                            CeedInt Q1d, CeedScalar **c_B_ptr, CeedScalar **c_G_ptr);
613*c532df63SYohann 
614*c532df63SYohann int CeedBasisApplyTensor_Cuda_shared(CeedBasis basis, const CeedInt nelem,
615*c532df63SYohann                                      CeedTransposeMode tmode,
616*c532df63SYohann                                      CeedEvalMode emode, CeedVector u, CeedVector v) {
617*c532df63SYohann   int ierr;
618*c532df63SYohann   Ceed ceed;
619*c532df63SYohann   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
620*c532df63SYohann   Ceed_Cuda_shared *ceed_Cuda;
621*c532df63SYohann   CeedGetData(ceed, (void *) &ceed_Cuda); CeedChk(ierr);
622*c532df63SYohann   CeedBasis_Cuda_shared *data;
623*c532df63SYohann   CeedBasisGetData(basis, (void *)&data); CeedChk(ierr);
624*c532df63SYohann   const CeedInt transpose = tmode == CEED_TRANSPOSE;
625*c532df63SYohann   // const int optElems[7] = {0,32,8,3,2,1,8};
626*c532df63SYohann   int elemsPerBlock = 1;//basis->Q1d < 7 ? optElems[basis->Q1d] : 1;
627*c532df63SYohann   int grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)?
628*c532df63SYohann                                      1 : 0 );
629*c532df63SYohann 
630*c532df63SYohann   const CeedScalar *d_u;
631*c532df63SYohann   CeedScalar *d_v;
632*c532df63SYohann   if(emode!=CEED_EVAL_WEIGHT) {
633*c532df63SYohann     ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChk(ierr);
634*c532df63SYohann   }
635*c532df63SYohann   ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChk(ierr);
636*c532df63SYohann 
637*c532df63SYohann   if (tmode == CEED_TRANSPOSE) {
638*c532df63SYohann     CeedInt length;
639*c532df63SYohann     ierr = CeedVectorGetLength(v, &length); CeedChk(ierr);
640*c532df63SYohann     ierr = cudaMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChk(ierr);
641*c532df63SYohann   }
642*c532df63SYohann   if (emode == CEED_EVAL_INTERP) {
643*c532df63SYohann     //TODO: check performance difference between c_B and d_B
644*c532df63SYohann     CeedInt P1d, Q1d;
645*c532df63SYohann     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
646*c532df63SYohann     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
647*c532df63SYohann     ierr = CeedCudaInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B);
648*c532df63SYohann     CeedChk(ierr);
649*c532df63SYohann     void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B, &d_u, &d_v};
650*c532df63SYohann     ierr = run_kernel_dim(ceed, data->interp, grid, Q1d, Q1d, elemsPerBlock,
651*c532df63SYohann                           interpargs);
652*c532df63SYohann     CeedChk(ierr);
653*c532df63SYohann   } else if (emode == CEED_EVAL_GRAD) {
654*c532df63SYohann     CeedInt P1d, Q1d;
655*c532df63SYohann     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
656*c532df63SYohann     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
657*c532df63SYohann     ierr = CeedCudaInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d,
658*c532df63SYohann                                   Q1d, &data->c_B, &data->c_G);
659*c532df63SYohann     CeedChk(ierr);
660*c532df63SYohann     void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B, &data->c_G, &d_u, &d_v};
661*c532df63SYohann     ierr = run_kernel_dim(ceed, data->grad, grid, Q1d, Q1d, elemsPerBlock,
662*c532df63SYohann                           gradargs);
663*c532df63SYohann     CeedChk(ierr);
664*c532df63SYohann   } else if (emode == CEED_EVAL_WEIGHT) {
665*c532df63SYohann     void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
666*c532df63SYohann     const int blocksize = 32;
667*c532df63SYohann     int gridsize = nelem/32;
668*c532df63SYohann     if (blocksize * gridsize < nelem)
669*c532df63SYohann       gridsize += 1;
670*c532df63SYohann     ierr = run_kernel(ceed, data->weight, gridsize, blocksize, weightargs);
671*c532df63SYohann   }
672*c532df63SYohann 
673*c532df63SYohann   if(emode!=CEED_EVAL_WEIGHT) {
674*c532df63SYohann     ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChk(ierr);
675*c532df63SYohann   }
676*c532df63SYohann   ierr = CeedVectorRestoreArray(v, &d_v); CeedChk(ierr);
677*c532df63SYohann 
678*c532df63SYohann   return 0;
679*c532df63SYohann }
680*c532df63SYohann 
681*c532df63SYohann static int CeedBasisDestroy_Cuda_shared(CeedBasis basis) {
682*c532df63SYohann   int ierr;
683*c532df63SYohann   Ceed ceed;
684*c532df63SYohann   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
685*c532df63SYohann 
686*c532df63SYohann   CeedBasis_Cuda_shared *data;
687*c532df63SYohann   ierr = CeedBasisGetData(basis, (void *) &data); CeedChk(ierr);
688*c532df63SYohann 
689*c532df63SYohann   CeedChk_Cu(ceed, cuModuleUnload(data->module));
690*c532df63SYohann 
691*c532df63SYohann   ierr = cudaFree(data->d_qweight1d); CeedChk_Cu(ceed, ierr);
692*c532df63SYohann   ierr = cudaFree(data->d_interp1d); CeedChk_Cu(ceed, ierr);
693*c532df63SYohann   ierr = cudaFree(data->d_grad1d); CeedChk_Cu(ceed, ierr);
694*c532df63SYohann 
695*c532df63SYohann   ierr = CeedFree(&data); CeedChk(ierr);
696*c532df63SYohann 
697*c532df63SYohann   return 0;
698*c532df63SYohann }
699*c532df63SYohann 
700*c532df63SYohann int CeedBasisCreateTensorH1_Cuda_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
701*c532df63SYohann                                         const CeedScalar *interp1d,
702*c532df63SYohann                                         const CeedScalar *grad1d,
703*c532df63SYohann                                         const CeedScalar *qref1d,
704*c532df63SYohann                                         const CeedScalar *qweight1d,
705*c532df63SYohann                                         CeedBasis basis) {
706*c532df63SYohann   int ierr;
707*c532df63SYohann   Ceed ceed;
708*c532df63SYohann   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
709*c532df63SYohann   CeedBasis_Cuda_shared *data;
710*c532df63SYohann   ierr = CeedCalloc(1, &data); CeedChk(ierr);
711*c532df63SYohann 
712*c532df63SYohann   const CeedInt qBytes = Q1d * sizeof(CeedScalar);
713*c532df63SYohann   ierr = cudaMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Cu(ceed, ierr);
714*c532df63SYohann   ierr = cudaMemcpy(data->d_qweight1d, qweight1d, qBytes,
715*c532df63SYohann                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
716*c532df63SYohann 
717*c532df63SYohann   const CeedInt iBytes = qBytes * P1d;
718*c532df63SYohann   ierr = cudaMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Cu(ceed, ierr);
719*c532df63SYohann   ierr = cudaMemcpy(data->d_interp1d, interp1d, iBytes,
720*c532df63SYohann                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
721*c532df63SYohann 
722*c532df63SYohann   ierr = cudaMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Cu(ceed, ierr);
723*c532df63SYohann   ierr = cudaMemcpy(data->d_grad1d, grad1d, iBytes,
724*c532df63SYohann                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
725*c532df63SYohann 
726*c532df63SYohann   CeedInt ncomp;
727*c532df63SYohann   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
728*c532df63SYohann   ierr = compile(ceed, kernelsShared, &data->module, 7,
729*c532df63SYohann                  "Q1D", Q1d,
730*c532df63SYohann                  "P1D", P1d,
731*c532df63SYohann                  "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
732*c532df63SYohann                      Q1d : P1d, dim),
733*c532df63SYohann                  "BASIS_DIM", dim,
734*c532df63SYohann                  "BASIS_NCOMP", ncomp,
735*c532df63SYohann                  "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
736*c532df63SYohann                  "BASIS_NQPT", CeedIntPow(Q1d, dim)
737*c532df63SYohann                 ); CeedChk(ierr);
738*c532df63SYohann   ierr = get_kernel(ceed, data->module, "interp", &data->interp);
739*c532df63SYohann   CeedChk(ierr);
740*c532df63SYohann   ierr = get_kernel(ceed, data->module, "grad", &data->grad);
741*c532df63SYohann   CeedChk(ierr);
742*c532df63SYohann   ierr = get_kernel(ceed, data->module, "weight", &data->weight);
743*c532df63SYohann   CeedChk(ierr);
744*c532df63SYohann 
745*c532df63SYohann   ierr = CeedBasisSetData(basis, (void *)&data);
746*c532df63SYohann   CeedChk(ierr);
747*c532df63SYohann   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
748*c532df63SYohann                                 CeedBasisApplyTensor_Cuda_shared);
749*c532df63SYohann   CeedChk(ierr);
750*c532df63SYohann   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
751*c532df63SYohann                                 CeedBasisDestroy_Cuda_shared);
752*c532df63SYohann   CeedChk(ierr);
753*c532df63SYohann   return 0;
754*c532df63SYohann }
755*c532df63SYohann 
756*c532df63SYohann int CeedBasisCreateH1_Cuda_shared(CeedElemTopology topo, CeedInt dim,
757*c532df63SYohann                                   CeedInt ndof, CeedInt nqpts,
758*c532df63SYohann                                   const CeedScalar *interp,
759*c532df63SYohann                                   const CeedScalar *grad,
760*c532df63SYohann                                   const CeedScalar *qref,
761*c532df63SYohann                                   const CeedScalar *qweight,
762*c532df63SYohann                                   CeedBasis basis) {
763*c532df63SYohann   int ierr;
764*c532df63SYohann   Ceed ceed;
765*c532df63SYohann   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
766*c532df63SYohann   return CeedError(ceed, 1, "Backend does not implement generic H1 basis");
767*c532df63SYohann }
768