xref: /libCEED/rust/libceed-sys/c-src/backends/cuda-shared/ceed-cuda-shared-basis.c (revision ab213215e569729a95fd10d21627f8060eaeb868)
1c532df63SYohann // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2c532df63SYohann // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3c532df63SYohann // All Rights reserved. See files LICENSE and NOTICE for details.
4c532df63SYohann //
5c532df63SYohann // This file is part of CEED, a collection of benchmarks, miniapps, software
6c532df63SYohann // libraries and APIs for efficient high-order finite element and spectral
7c532df63SYohann // element discretizations for exascale applications. For more information and
8c532df63SYohann // source code availability see http://github.com/ceed.
9c532df63SYohann //
10c532df63SYohann // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11c532df63SYohann // a collaborative effort of two U.S. Department of Energy organizations (Office
12c532df63SYohann // of Science and the National Nuclear Security Administration) responsible for
13c532df63SYohann // the planning and preparation of a capable exascale ecosystem, including
14c532df63SYohann // software, applications, hardware, advanced system engineering and early
15c532df63SYohann // testbed platforms, in support of the nation's exascale computing imperative.
16c532df63SYohann 
17c532df63SYohann #include <ceed-backend.h>
18c532df63SYohann #include <ceed.h>
19c532df63SYohann #include "ceed-cuda-shared.h"
20c532df63SYohann #include "../cuda/ceed-cuda.h"
21c532df63SYohann 
22*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
23*ab213215SJeremy L Thompson // Shared mem kernels
24*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
25cb0b5415Sjeremylt // *INDENT-OFF*
26c532df63SYohann static const char *kernelsShared = QUOTE(
27c532df63SYohann 
28*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
29*ab213215SJeremy L Thompson // Sum input into output
30*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Element-wise accumulation: adds each of the first Q1D entries of r_U to the
// matching entry of r_V. r_U is left untouched.
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  int q = 0;
  while (q < Q1D) {
    r_V[q] += r_U[q];
    ++q;
  }
}
35c532df63SYohann 
36*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
37*ab213215SJeremy L Thompson // 1D
38*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
39c532df63SYohann 
40*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
41*ab213215SJeremy L Thompson // Read DoFs
42*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Fill the row of `slice` selected by tidz (row stride Q1D):
//   - entries 0..P1D-1 are copied from d_U for component `comp` of element
//     `elem` (d_U is indexed as i + comp*P1D + elem*BASIS_NCOMP*P1D),
//   - entries P1D..Q1D-1 are set to zero.
// Every calling thread writes the whole row; tidx, tidy and nelem are unused.
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int tidz, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *slice) {
  const int rowStart = tidz*Q1D;
  const int srcStart = comp*P1D + elem*BASIS_NCOMP*P1D;

  // Copy the nodal values
  for (int node = 0; node < P1D; node++) {
    slice[rowStart + node] = d_U[srcStart + node];
  }
  // Zero the remainder of the row
  for (int pad = P1D; pad < Q1D; pad++) {
    slice[rowStart + pad] = 0.0;
  }
}
52c532df63SYohann 
53*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
54*ab213215SJeremy L Thompson // Write DoFs
55*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Store r_V into d_V at tidx + comp*P1D + elem*BASIS_NCOMP*P1D.
// Threads with tidx >= P1D store nothing. tidy and nelem are unused.
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D)
    return;
  const int dst = tidx + comp*P1D + elem*BASIS_NCOMP*P1D;
  d_V[dst] = r_V;
}
63c532df63SYohann 
64*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
65*ab213215SJeremy L Thompson // Read quadrature point data
66*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Copy Q1D consecutive values from d_U into the row of `slice` selected by
// tidz (row stride Q1D). The source run starts at
//   elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D.
// Every calling thread copies the whole row; tidx and tidy are unused.
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int tidz, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *slice) {
  const int rowStart = tidz*Q1D;
  const int srcStart = elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  for (int q = 0; q < Q1D; q++) {
    slice[rowStart + q] = d_U[srcStart + q];
  }
}
75c532df63SYohann 
76*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
77*ab213215SJeremy L Thompson // Write quadrature point data
78*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Store r_V into d_V at
//   tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D.
// No bound check is applied to tidx; tidy is unused.
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  const int dst = tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  d_V[dst] = r_V;
}
85c532df63SYohann 
86*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
87*ab213215SJeremy L Thompson // 1D tensor contraction
88*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Contract in the x direction:
//   V = sum_{i < P1D} B[i + tidx*P1D] * slice[i + tidz*Q1D]
// The input is taken from `slice`; the U and tidy arguments are not used.
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  const int rowStart = tidz*Q1D;
  const int basisStart = tidx*P1D;
  CeedScalar acc = 0.0;
  for (int i = 0; i < P1D; ++i) {
    acc += B[basisStart + i] * slice[rowStart + i];
  }
  V = acc;
}
97c532df63SYohann 
98*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
99*ab213215SJeremy L Thompson // 1D transpose tensor contraction
100*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Transpose contraction in the x direction:
//   V = sum_{i < Q1D} B[tidx + i*P1D] * slice[i + tidz*Q1D]   for tidx < P1D
//   V = 0                                                     otherwise
// The input is taken from `slice`; the U and tidy arguments are not used.
//
// The tidx < P1D guard mirrors ContractTransposeX2d. Without it, a thread with
// tidx >= P1D indexes B up to (Q1D-1)*(P1D+1), which is at or beyond P1D*Q1D
// whenever Q1D > P1D (assuming B holds P1D*Q1D entries, as the indexing
// implies). The result for such threads is discarded by writeDofs1d, so
// returning zero for them does not change what is stored.
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  V = 0.0;
  if (tidx < P1D)
    for (int i = 0; i < Q1D; ++i)
      V += B[tidx + i*P1D] * slice[i + tidz*Q1D]; // Contract x direction
}
108c532df63SYohann 
109*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
110*ab213215SJeremy L Thompson // 1D interpolate to quadrature points
111*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// 1D interpolation driver.
// Each thread starts at element blockIdx.x*blockDim.z + threadIdx.z and
// advances by gridDim.x*blockDim.z until nelem is reached. For every component:
//   transpose == 0: readDofs1d -> ContractX1d(c_B) -> writeQuads1d (dim 0)
//   transpose != 0: readQuads1d -> ContractTransposeX1d(c_B) -> writeDofs1d
// `slice` is the scratch buffer the read helpers fill and the contractions
// consume.
// NOTE(review): there is no barrier between the slice writes of one component
// iteration and the slice reads of the previous one by other threads sharing
// the same tidz - confirm whether a __syncthreads is required here.
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;

  CeedScalar r_out;    // contraction result for this thread
  CeedScalar r_unused; // passed as the (unread) U argument of the contractions

  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for (int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (transpose) {
        // Quadrature points -> nodes
        readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_unused, c_B, r_out);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_out, d_V);
      } else {
        // Nodes -> quadrature points
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_unused, c_B, r_out);
        writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_out, d_V);
      }
    }
  }
}
140c532df63SYohann 
141*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
142*ab213215SJeremy L Thompson // 1D derivatives at quadrature points
143*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// 1D gradient driver.
// Same element/component traversal as the 1D interpolation routine, but the
// contraction uses c_G; c_B is accepted and not used. The only dimension
// index passed to the quadrature read/write helpers is 0.
//   transpose == 0: readDofs1d -> ContractX1d(c_G) -> writeQuads1d
//   transpose != 0: readQuads1d -> ContractTransposeX1d(c_G) -> writeDofs1d
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int dim = 0;

  CeedScalar r_unused; // passed as the (unread) U argument of the contractions
  CeedScalar r_out;    // contraction result for this thread

  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for (int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (transpose) {
        readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_unused, c_G, r_out);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_out, d_V);
      } else {
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_unused, c_G, r_out);
        writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_out, d_V);
      }
    }
  }
}
174c532df63SYohann 
175*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
176*ab213215SJeremy L Thompson // 1D Quadrature weights
177*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Write the 1D quadrature weight qweight1d[threadIdx.x] to
// w[elem*Q1D + threadIdx.x] for every element this thread visits. Elements
// start at blockIdx.x*blockDim.y + threadIdx.y and advance by
// gridDim.x*blockDim.y while below nelem.
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int q = threadIdx.x;
  const CeedScalar wq = qweight1d[q];
  CeedInt elem = blockIdx.x*blockDim.y + threadIdx.y;
  while (elem < nelem) {
    w[elem*Q1D + q] = wq;
    elem += gridDim.x*blockDim.y;
  }
}
188*ab213215SJeremy L Thompson 
189*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
190*ab213215SJeremy L Thompson // 2D
191*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
192*ab213215SJeremy L Thompson 
193*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
194*ab213215SJeremy L Thompson // Read DoFs
195*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Load one nodal value into U. Threads with tidx < P1D and tidy < P1D read
//   d_U[tidx + tidy*P1D + comp*P1D*P1D + elem*BASIS_NCOMP*P1D*P1D];
// all other threads get 0. nelem is unused.
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar &U) {
  if (tidx < P1D && tidy < P1D) {
    const int src = tidx + tidy*P1D + comp*P1D*P1D + elem*BASIS_NCOMP*P1D*P1D;
    U = d_U[src];
  } else {
    U = 0.0;
  }
}
205c532df63SYohann 
206*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
207*ab213215SJeremy L Thompson // Write DoFs
208*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Store r_V into
//   d_V[tidx + tidy*P1D + comp*P1D*P1D + elem*BASIS_NCOMP*P1D*P1D].
// Threads with tidx >= P1D or tidy >= P1D store nothing. nelem is unused.
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D)
    return;
  const int dst = tidx + tidy*P1D + comp*P1D*P1D + elem*BASIS_NCOMP*P1D*P1D;
  d_V[dst] = r_V;
}
216c532df63SYohann 
217*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
218*ab213215SJeremy L Thompson // Read quadrature point data
219*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Load one quadrature-point value into U from
//   d_U[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem
//       + dim*BASIS_NCOMP*nelem*Q1D*Q1D].
// No bound check is applied to tidx or tidy.
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar &U ) {
  const int src = tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                  dim*BASIS_NCOMP*nelem*Q1D*Q1D;
  U = d_U[src];
}
227c532df63SYohann 
228*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
229*ab213215SJeremy L Thompson // Write quadrature point data
230*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Store r_V into
//   d_V[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem
//       + dim*BASIS_NCOMP*nelem*Q1D*Q1D].
// No bound check is applied to tidx or tidy.
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  const int dst = tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                  dim*BASIS_NCOMP*nelem*Q1D*Q1D;
  d_V[dst] = r_V;
}
238c532df63SYohann 
239*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
240*ab213215SJeremy L Thompson // 2D tensor contraction x
241*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Contract in the x direction within the Q1D x Q1D plane of `slice` selected
// by tidz. Each thread first publishes U at (tidx, tidy) of that plane, then
// after a block barrier computes
//   V = sum_{i < P1D} B[i + tidx*P1D] * slice[i + tidy*Q1D + tidz*Q1D*Q1D].
// The trailing barrier keeps the plane intact until every thread has read it.
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  const int plane = tidz*Q1D*Q1D;
  const int rowStart = tidy*Q1D + plane;

  slice[tidx + rowStart] = U;
  __syncthreads();

  CeedScalar acc = 0.0;
  for (int i = 0; i < P1D; ++i) {
    acc += B[i + tidx*P1D] * slice[i + rowStart];
  }
  V = acc;
  __syncthreads();
}
253c532df63SYohann 
254*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
255*ab213215SJeremy L Thompson // 2D tensor contraction y
256*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Contract in the y direction within the Q1D x Q1D plane of `slice` selected
// by tidz. Each thread first publishes U at (tidx, tidy) of that plane, then
// after a block barrier computes
//   V = sum_{i < P1D} B[i + tidy*P1D] * slice[tidx + i*Q1D + tidz*Q1D*Q1D].
// The trailing barrier keeps the plane intact until every thread has read it.
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  const int plane = tidz*Q1D*Q1D;
  const int colStart = tidx + plane;

  slice[colStart + tidy*Q1D] = U;
  __syncthreads();

  CeedScalar acc = 0.0;
  for (int i = 0; i < P1D; ++i) {
    acc += B[i + tidy*P1D] * slice[colStart + i*Q1D];
  }
  V = acc;
  __syncthreads();
}
268c532df63SYohann 
269*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
270*ab213215SJeremy L Thompson // 2D transpose tensor contraction y
271*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Transpose contraction in the y direction within the plane of `slice`
// selected by tidz. Each thread publishes U at (tidx, tidy), then after a
// block barrier threads with tidy < P1D compute
//   V = sum_{i < Q1D} B[tidy + i*P1D] * slice[tidx + i*Q1D + tidz*Q1D*Q1D];
// threads with tidy >= P1D get V = 0. Both barriers are reached by every
// calling thread regardless of the tidy test.
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  const int plane = tidz*Q1D*Q1D;
  const int colStart = tidx + plane;

  slice[colStart + tidy*Q1D] = U;
  __syncthreads();

  CeedScalar acc = 0.0;
  if (tidy < P1D) {
    for (int i = 0; i < Q1D; ++i) {
      acc += B[tidy + i*P1D] * slice[colStart + i*Q1D];
    }
  }
  V = acc;
  __syncthreads();
}
283c532df63SYohann 
284*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
285*ab213215SJeremy L Thompson // 2D transpose tensor contraction x
286*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Transpose contraction in the x direction within the plane of `slice`
// selected by tidz. Each thread publishes U at (tidx, tidy), then after a
// block barrier threads with tidx < P1D compute
//   V = sum_{i < Q1D} B[tidx + i*P1D] * slice[i + tidy*Q1D + tidz*Q1D*Q1D];
// threads with tidx >= P1D get V = 0. Both barriers are reached by every
// calling thread regardless of the tidx test.
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  const int plane = tidz*Q1D*Q1D;
  const int rowStart = tidy*Q1D + plane;

  slice[tidx + rowStart] = U;
  __syncthreads();

  CeedScalar acc = 0.0;
  if (tidx < P1D) {
    for (int i = 0; i < Q1D; ++i) {
      acc += B[tidx + i*P1D] * slice[i + rowStart];
    }
  }
  V = acc;
  __syncthreads();
}
298c532df63SYohann 
299*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
300*ab213215SJeremy L Thompson // 2D interpolate to quadrature points
301*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// 2D interpolation driver.
// threadIdx.z encodes both the component (tidz % BASIS_NCOMP) and the element
// slot inside the block (tidz / BASIS_NCOMP); a block therefore covers
// blockDim.z / BASIS_NCOMP elements per pass. Each thread starts at element
// blockIdx.x*elemsPerBlock + blockElem and advances by gridDim.x*elemsPerBlock
// while below nelem.
//   transpose == 0: readDofs2d -> ContractX2d(c_B) -> ContractY2d(c_B)
//                   -> writeQuads2d (dim 0)
//   transpose != 0: readQuads2d (dim 0) -> ContractTransposeY2d(c_B)
//                   -> ContractTransposeX2d(c_B) -> writeDofs2d
// `slice` is the scratch buffer used by the contraction helpers, which
// synchronize the block internally.
//
// The component index is computed once at function scope; it does not depend
// on the element, so it is not recomputed (or shadowed) inside the loop.
// r_V and r_t need no zero initialization: both branches assign them through
// the read and contraction helpers before they are read.
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if (!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
    } else {
      readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
335c532df63SYohann 
336*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
337*ab213215SJeremy L Thompson // 2D derivatives at quadrature points
338*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// 2D gradient driver.
// threadIdx.z encodes the component (tidz % BASIS_NCOMP) and the element slot
// inside the block (tidz / BASIS_NCOMP). Each thread starts at element
// blockIdx.x*elemsPerBlock + blockElem and advances by gridDim.x*elemsPerBlock
// while below nelem.
//   transpose == 0: read the nodal value once, then
//       dim 0: ContractX2d(c_G) -> ContractY2d(c_B) -> writeQuads2d
//       dim 1: ContractX2d(c_B) -> ContractY2d(c_G) -> writeQuads2d
//   transpose != 0:
//       dim 0: readQuads2d -> ContractTransposeY2d(c_B)
//              -> ContractTransposeX2d(c_G)
//       dim 1: readQuads2d -> ContractTransposeY2d(c_G)
//              -> ContractTransposeX2d(c_B)
//       the two results are summed and stored with writeDofs2d.
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V, CeedScalar *slice) {
  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int comp = tidz%BASIS_NCOMP;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;

  CeedScalar r_in;  // value read from d_U (reused for the second partial sum)
  CeedScalar r_mid; // result of the first contraction of each pair
  CeedScalar r_out; // value written to d_V

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if (transpose) {
      // Contribution of dimension 0
      readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_in);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_in, c_B, r_mid);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_mid, c_G, r_out);
      // Contribution of dimension 1
      readQuads2d(elem, tidx, tidy, comp, 1, nelem, d_U, r_in);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_in, c_G, r_mid);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_mid, c_B, r_in);
      r_out += r_in;
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_out, d_V);
    } else {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_in);
      // Dimension 0
      ContractX2d(slice, tidx, tidy, tidz, r_in, c_G, r_mid);
      ContractY2d(slice, tidx, tidy, tidz, r_mid, c_B, r_out);
      writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_out, d_V);
      // Dimension 1
      ContractX2d(slice, tidx, tidy, tidz, r_in, c_B, r_mid);
      ContractY2d(slice, tidx, tidy, tidz, r_mid, c_G, r_out);
      writeQuads2d(elem, tidx, tidy, comp, 1, nelem, r_out, d_V);
    }
  }
}
381c532df63SYohann 
382*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
383*ab213215SJeremy L Thompson // 2D quadrature weights
384*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Write the 2D quadrature weight qweight1d[threadIdx.x]*qweight1d[threadIdx.y]
// to w[elem*Q1D*Q1D + threadIdx.x + threadIdx.y*Q1D] for every element this
// thread visits. Elements start at blockIdx.x*blockDim.z + threadIdx.z and
// advance by gridDim.x*blockDim.z while below nelem.
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy];
  CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z;
  while (elem < nelem) {
    w[elem*Q1D*Q1D + qx + qy*Q1D] = wq;
    elem += gridDim.x*blockDim.z;
  }
}
396*ab213215SJeremy L Thompson 
397*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
398*ab213215SJeremy L Thompson // 3D
399*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
400*ab213215SJeremy L Thompson 
401*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
402*ab213215SJeremy L Thompson // Read DoFs
403*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Load a column of nodal values into r_U. For i < P1D, threads with
// tidx < P1D and tidy < P1D read
//   d_U[tidx + tidy*P1D + i*P1D*P1D + comp*P1D*P1D*P1D
//       + elem*BASIS_NCOMP*P1D*P1D*P1D]
// and all other threads store 0. Entries P1D..Q1D-1 of r_U are always zeroed.
// nelem is unused.
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *r_U) {
  const bool inRange = (tidx < P1D) && (tidy < P1D);
  const int base = tidx + tidy*P1D + comp*P1D*P1D*P1D +
                   elem*BASIS_NCOMP*P1D*P1D*P1D;

  for (int k = 0; k < P1D; k++) {
    if (inRange)
      r_U[k] = d_U[base + k*P1D*P1D];
    else
      r_U[k] = 0.0;
  }
  for (int k = P1D; k < Q1D; k++) {
    r_U[k] = 0.0;
  }
}
415c532df63SYohann 
416*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
417*ab213215SJeremy L Thompson // Read quadrature point data
418*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Load a column of Q1D quadrature-point values into r_U:
//   r_U[i] = d_U[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D
//                + comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D].
// No bound check is applied to tidx or tidy.
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *r_U) {
  const int base = tidx + tidy*Q1D + elem*Q1D*Q1D*Q1D +
                   comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
  for (int k = 0; k < Q1D; k++) {
    r_U[k] = d_U[base + k*Q1D*Q1D];
  }
}
427c532df63SYohann 
428*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
429*ab213215SJeremy L Thompson // Write DoFs
430*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Store a column of P1D nodal values:
//   d_V[tidx + tidy*P1D + i*P1D*P1D + comp*P1D*P1D*P1D
//       + elem*BASIS_NCOMP*P1D*P1D*P1D] = r_V[i]   for i < P1D.
// Threads with tidx >= P1D or tidy >= P1D store nothing. nelem is unused.
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D)
    return;
  const int base = tidx + tidy*P1D + comp*P1D*P1D*P1D +
                   elem*BASIS_NCOMP*P1D*P1D*P1D;
  for (int k = 0; k < P1D; k++) {
    d_V[base + k*P1D*P1D] = r_V[k];
  }
}
441c532df63SYohann 
442*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
443*ab213215SJeremy L Thompson // Write quadrature point data
444*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Store a column of Q1D quadrature-point values:
//   d_V[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D
//       + comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] = r_V[i].
// No bound check is applied to tidx or tidy.
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar *r_V, CeedScalar *d_V) {
  const int base = tidx + tidy*Q1D + elem*Q1D*Q1D*Q1D +
                   comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
  for (int k = 0; k < Q1D; k++) {
    d_V[base + k*Q1D*Q1D] = r_V[k];
  }
}
453c532df63SYohann 
454*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
455*ab213215SJeremy L Thompson // 3D tensor contract x
456*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Contract in the x direction, one z-layer at a time. For each k < P1D the
// thread publishes U[k] at (tidx, tidy) of the plane of `slice` selected by
// tidz, waits on a block barrier, and computes
//   V[k] = sum_{i < P1D} B[i + tidx*P1D] * slice[i + tidy*Q1D + tidz*Q1D*Q1D].
// A second barrier ends each layer so the plane is not overwritten while
// other threads are still reading it.
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B,
                                   CeedScalar *V) {
  const int rowStart = tidy*Q1D + tidz*Q1D*Q1D;
  const int own = tidx + rowStart;

  for (int k = 0; k < P1D; ++k) {
    slice[own] = U[k];
    __syncthreads();

    CeedScalar acc = 0.0;
    for (int i = 0; i < P1D; ++i) {
      acc += B[i + tidx*P1D] * slice[i + rowStart];
    }
    V[k] = acc;
    __syncthreads();
  }
}
470c532df63SYohann 
471*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
472*ab213215SJeremy L Thompson // 3D tensor contract y
473*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// For each of the P1D layers held in U, computes
//   V[layer] = sum_p B[p + tidy*P1D] * u(tidx, p)
// with the layer values exchanged between threads through slice. The function
// contains block-wide barriers, so every thread of the block has to call it.
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B,
                                   CeedScalar *V) {
  // Start of the y-column this thread belongs to inside slice
  const int column = tidx + tidz*Q1D*Q1D;
  const CeedScalar *Brow = B + tidy*P1D;
  for (int layer = 0; layer < P1D; ++layer) {
    slice[column + tidy*Q1D] = U[layer];
    __syncthreads(); // the whole layer must be stored before it is read
    CeedScalar sum = 0.0;
    for (int p = 0; p < P1D; ++p)
      sum += Brow[p] * slice[column + p*Q1D]; // Contract y direction
    V[layer] = sum;
    __syncthreads(); // all reads finished before slice is overwritten
  }
}
487c532df63SYohann 
488*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
489*ab213215SJeremy L Thompson // 3D tensor contract z
490*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Contracts the thread-local values along z:
//   V[q] = sum_p B[p + q*P1D] * U[p]   for q in [0, Q1D)
// Only U, B and V are used; slice and the thread indices are accepted so the
// signature matches the x and y contractions.
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B,
                                   CeedScalar *V) {
  for (int q = 0; q < Q1D; ++q) {
    const CeedScalar *Brow = B + q*P1D;
    V[q] = 0.0;
    for (int p = 0; p < P1D; ++p)
      V[q] += Brow[p] * U[p]; // Contract z direction
  }
}
501c532df63SYohann 
502*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
503*ab213215SJeremy L Thompson // 3D transpose tensor contract z
504*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Transposed z contraction on thread-local values:
//   V[k] = sum_q B[k + q*P1D] * U[q]   for k in [0, P1D)
//   V[k] = 0                           for k in [P1D, Q1D)
// Only U, B and V are used; the remaining parameters keep the signature
// aligned with the other contractions.
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < Q1D; ++k) {
    V[k] = 0.0;
    if (k >= P1D)
      continue; // entries past P1D stay zero
    const CeedScalar *Bcol = B + k;
    for (int q = 0; q < Q1D; ++q)
      V[k] += Bcol[q*P1D] * U[q]; // Contract z direction
  }
}
515c532df63SYohann 
516*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
517*ab213215SJeremy L Thompson // 3D transpose tensor contract y
518*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Transposed y contraction. For each of the P1D layers held in U:
//   V[layer] = sum_q B[tidy + q*P1D] * u(tidx, q)   when tidy < P1D
//   V[layer] = 0                                    otherwise
// Layer values are exchanged through slice. The barriers sit outside the
// tidy test, so every thread of the block has to call this function.
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  const int column = tidx + tidz*Q1D*Q1D;
  const int active = tidy < P1D;
  for (int layer = 0; layer < P1D; ++layer) {
    slice[column + tidy*Q1D] = U[layer];
    __syncthreads(); // the whole layer must be stored before it is read
    CeedScalar sum = 0.0;
    if (active)
      for (int q = 0; q < Q1D; ++q)
        sum += B[tidy + q*P1D] * slice[column + q*Q1D]; // Contract y direction
    V[layer] = sum;
    __syncthreads(); // all reads finished before slice is overwritten
  }
}
532c532df63SYohann 
533*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
534*ab213215SJeremy L Thompson // 3D transpose tensor contract x
535*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Transposed x contraction. For each of the P1D layers held in U:
//   V[layer] = sum_q B[tidx + q*P1D] * u(q, tidy)   when tidx < P1D
//   V[layer] = 0                                    otherwise
// Layer values are exchanged through slice. The barriers sit outside the
// tidx test, so every thread of the block has to call this function.
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  const int row = tidy*Q1D + tidz*Q1D*Q1D;
  const int active = tidx < P1D;
  for (int layer = 0; layer < P1D; ++layer) {
    slice[tidx + row] = U[layer];
    __syncthreads(); // the whole layer must be stored before it is read
    CeedScalar sum = 0.0;
    if (active)
      for (int q = 0; q < Q1D; ++q)
        sum += B[tidx + q*P1D] * slice[q + row]; // Contract x direction
    V[layer] = sum;
    __syncthreads(); // all reads finished before slice is overwritten
  }
}
549c532df63SYohann 
550*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
551*ab213215SJeremy L Thompson // 3D interpolate to quadrature points
552*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Applies the 1D matrix c_B in x, y and z to move between nodes and
// quadrature points (or back when transpose is nonzero).
// Thread mapping: threadIdx.x / threadIdx.y select the position inside a
// layer, threadIdx.z encodes (element within the block, component) as
// tidz = blockElem*BASIS_NCOMP + comp. Elements are visited with a stride of
// gridDim.x*elemsPerBlock.
// NOTE(review): the contractions contain block-wide barriers while the element
// loop bound depends on tidz; this is uniform per block for the 3D launch in
// this file, which uses one element per block - confirm before changing that.
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V[Q1D];
  CeedScalar r_t[Q1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + tidz/BASIS_NCOMP; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // Start every element from cleared registers
    for (int q = 0; q < Q1D; ++q) {
      r_V[q] = 0.0;
      r_t[q] = 0.0;
    }
    if (transpose) {
      // Quadrature points -> nodes: z, then y, then x
      readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
    } else {
      // Nodes -> quadrature points: x, then y, then z
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
    }
  }
}
589c532df63SYohann 
590*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
591*ab213215SJeremy L Thompson // 3D derivatives at quadrature points
592*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Evaluates the three derivative components at the quadrature points, or
// applies the transposed operation when transpose is nonzero.
// Forward: for direction d the derivative matrix c_G is used in direction d
// and the interpolation matrix c_B in the two others; the result is written
// with direction index d.
// Transpose: the three direction blocks are read one after another, pushed
// through the transposed contractions and summed with add() before the nodal
// values are written.
// Thread mapping is the same as in interp3d: threadIdx.z encodes
// (element within the block, component).
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // TODO (kept from the original code): use P1D for one of these
  CeedScalar r_U[Q1D];
  CeedScalar r_V[Q1D];
  CeedScalar r_t[Q1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + tidz/BASIS_NCOMP; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if (transpose) {
      // Direction 0: derivative matrix in x
      readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      // Direction 1: derivative matrix in y
      readQuads3d(elem, tidx, tidy, comp, 1, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_G, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      // Direction 2: derivative matrix in z
      readQuads3d(elem, tidx, tidy, comp, 2, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    } else {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // Direction 0: derivative matrix in x
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      // Direction 1: derivative matrix in y
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_G, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeQuads3d(elem, tidx, tidy, comp, 1, nelem, r_V, d_V);
      // Direction 2: derivative matrix in z
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      writeQuads3d(elem, tidx, tidy, comp, 2, nelem, r_V, d_V);
    }
  }
}
652c532df63SYohann 
653*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
654*ab213215SJeremy L Thompson // 3D quadrature weights
655*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Fills w with the tensor-product quadrature weights of every element.
// The thread at (threadIdx.x, threadIdx.y, threadIdx.z) owns one quadrature
// point; its weight is the product of the three 1D weights and is written at
// that point for each element e = blockIdx.x, blockIdx.x + gridDim.x, ...
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const int qz = threadIdx.z;
  const CeedScalar weight = qweight1d[qx]*qweight1d[qy]*qweight1d[qz];
  // Position of this quadrature point inside one element
  const int local = qx + qy*Q1D + qz*Q1D*Q1D;
  for (int e = blockIdx.x; e < nelem; e += gridDim.x)
    w[e*Q1D*Q1D*Q1D + local] = weight;
}
667*ab213215SJeremy L Thompson 
668*ab213215SJeremy L Thompson 
669*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
670*ab213215SJeremy L Thompson // Basis kernels
671*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
672*ab213215SJeremy L Thompson 
673*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
674*ab213215SJeremy L Thompson // Interp kernel by dim
675*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Kernel entry point for interpolation; dispatches on the compile-time
// constant BASIS_DIM.
//   nelem     : number of elements
//   transpose : nonzero to apply the transposed operation
//   c_B       : 1D interpolation matrix
//   d_U, d_V  : input and output arrays
// Launch layout used by the host code in this file: block (Q1d, 1, 32) in 1D
// and (Q1d, Q1d, ncomp*elemsPerBlock) in 2D and 3D, with dynamic shared memory
// of one CeedScalar per thread of the block.
extern "C" __global__ void interp(const CeedInt nelem, const int transpose,
                                  const CeedScalar *c_B,
                                  const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  // Declared with CeedScalar so the buffer always matches the type expected
  // by the device functions and the size computed on the host
  extern __shared__ CeedScalar slice[];
  if (BASIS_DIM == 1) {
    interp1d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM == 2) {
    interp2d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM == 3) {
    interp3d(nelem, transpose, c_B, d_U, d_V, slice);
  }
}
689c532df63SYohann 
690*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
691*ab213215SJeremy L Thompson // Grad kernel by dim
692*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Kernel entry point for gradient evaluation; dispatches on the compile-time
// constant BASIS_DIM.
//   nelem     : number of elements
//   transpose : nonzero to apply the transposed operation
//   c_B, c_G  : 1D interpolation and derivative matrices
//   d_U, d_V  : input and output arrays
// Launch layout used by the host code in this file: block (Q1d, 1, 32) in 1D
// and (Q1d, Q1d, ncomp*elemsPerBlock) in 2D and 3D, with dynamic shared memory
// of one CeedScalar per thread of the block.
extern "C" __global__ void grad(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *c_G,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V) {
  // Declared with CeedScalar so the buffer always matches the type expected
  // by the device functions and the size computed on the host
  extern __shared__ CeedScalar slice[];
  if (BASIS_DIM == 1) {
    grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM == 2) {
    grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM == 3) {
    grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  }
}
706c532df63SYohann 
707*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
708*ab213215SJeremy L Thompson // Weight kernels by dim
709*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Kernel entry point for the quadrature weights; forwards to the routine that
// matches the compile-time constant BASIS_DIM and does nothing for any other
// value.
//   nelem     : number of elements
//   qweight1d : 1D quadrature weights
//   v         : output array receiving the weights
extern "C" __global__ void weight(const CeedInt nelem,
                                  const CeedScalar *__restrict__ qweight1d,
                                  CeedScalar *__restrict__ v) {
  switch (BASIS_DIM) {
  case 1:
    weight1d(nelem, qweight1d, v);
    break;
  case 2:
    weight2d(nelem, qweight1d, v);
    break;
  case 3:
    weight3d(nelem, qweight1d, v);
    break;
  default:
    break;
  }
}
721c532df63SYohann 
722c532df63SYohann );
723cb0b5415Sjeremylt // *INDENT-ON*
724c532df63SYohann 
725*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Device initialization
727*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Forward declarations; no definition is visible in this part of the file.
// CeedBasisApplyTensor_Cuda_shared calls these before launching the interp and
// grad kernels and then passes the returned c_B / c_G pointers as kernel
// arguments.
// NOTE(review): presumably these stage the 1D matrices d_B / d_G in a device
// buffer and return its address - confirm against the definitions.
int CeedCudaInitInterp(CeedScalar *d_B, CeedInt P1d, CeedInt Q1d,
                       CeedScalar **c_B);
int CeedCudaInitInterpGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P1d,
                           CeedInt Q1d, CeedScalar **c_B_ptr,
                           CeedScalar **c_G_ptr);
733c532df63SYohann 
734*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
735*ab213215SJeremy L Thompson // Apply basis
736*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
737c532df63SYohann int CeedBasisApplyTensor_Cuda_shared(CeedBasis basis, const CeedInt nelem,
738c532df63SYohann                                      CeedTransposeMode tmode,
7397f823360Sjeremylt                                      CeedEvalMode emode, CeedVector u,
7407f823360Sjeremylt                                      CeedVector v) {
741c532df63SYohann   int ierr;
742c532df63SYohann   Ceed ceed;
743c532df63SYohann   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
744c532df63SYohann   Ceed_Cuda_shared *ceed_Cuda;
745c532df63SYohann   CeedGetData(ceed, (void *) &ceed_Cuda); CeedChk(ierr);
746c532df63SYohann   CeedBasis_Cuda_shared *data;
747c532df63SYohann   CeedBasisGetData(basis, (void *)&data); CeedChk(ierr);
748c532df63SYohann   const CeedInt transpose = tmode == CEED_TRANSPOSE;
7494247ecf3SYohann Dudouit   CeedInt dim, ncomp;
750074be161SYohann Dudouit   ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
7514247ecf3SYohann Dudouit   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
752c532df63SYohann 
753*ab213215SJeremy L Thompson   // Read vectors
754c532df63SYohann   const CeedScalar *d_u;
755c532df63SYohann   CeedScalar *d_v;
756c532df63SYohann   if (emode != CEED_EVAL_WEIGHT) {
757c532df63SYohann     ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChk(ierr);
758c532df63SYohann   }
759c532df63SYohann   ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChk(ierr);
760c532df63SYohann 
761*ab213215SJeremy L Thompson   // Clear v for transpose mode
762c532df63SYohann   if (tmode == CEED_TRANSPOSE) {
763c532df63SYohann     CeedInt length;
764c532df63SYohann     ierr = CeedVectorGetLength(v, &length); CeedChk(ierr);
765c532df63SYohann     ierr = cudaMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChk(ierr);
766c532df63SYohann   }
767*ab213215SJeremy L Thompson 
768*ab213215SJeremy L Thompson   // Apply basis operation
769*ab213215SJeremy L Thompson   switch (emode) {
770*ab213215SJeremy L Thompson   case CEED_EVAL_INTERP: {
771c532df63SYohann     CeedInt P1d, Q1d;
772c532df63SYohann     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
773c532df63SYohann     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
774c532df63SYohann     ierr = CeedCudaInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B);
775c532df63SYohann     CeedChk(ierr);
776cb0b5415Sjeremylt     void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
777ccf0fe6fSjeremylt                           &d_u, &d_v
778ccf0fe6fSjeremylt                          };
7794d537eeaSYohann     if (dim == 1) {
780d94769d2SYohann Dudouit       CeedInt elemsPerBlock = 32;
7814d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
7824d537eeaSYohann                                              ? 1 : 0 );
783d94769d2SYohann Dudouit       CeedInt sharedMem = elemsPerBlock*Q1d*sizeof(CeedScalar);
7844d537eeaSYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, 1,
7854d537eeaSYohann                                         elemsPerBlock, sharedMem,
786*ab213215SJeremy L Thompson                                         interpargs); CeedChk(ierr);
787074be161SYohann Dudouit     } else if (dim == 2) {
7884247ecf3SYohann Dudouit       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
7894247ecf3SYohann Dudouit       CeedInt elemsPerBlock = Q1d < 7 ? optElems[Q1d]/ncomp : 1;
7904d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
7914d537eeaSYohann                                              ? 1 : 0 );
7924247ecf3SYohann Dudouit       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
7934d537eeaSYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, Q1d,
7944d537eeaSYohann                                         ncomp*elemsPerBlock, sharedMem,
795*ab213215SJeremy L Thompson                                         interpargs); CeedChk(ierr);
796074be161SYohann Dudouit     } else if (dim == 3) {
7973f63d318SYohann Dudouit       CeedInt elemsPerBlock = 1;
7984d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
7994d537eeaSYohann                                              ? 1 : 0 );
800698ebc35SYohann Dudouit       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
8014d537eeaSYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, Q1d,
8024d537eeaSYohann                                         ncomp*elemsPerBlock, sharedMem,
803*ab213215SJeremy L Thompson                                         interpargs); CeedChk(ierr);
804074be161SYohann Dudouit     }
805*ab213215SJeremy L Thompson   } break;
806*ab213215SJeremy L Thompson   case CEED_EVAL_GRAD: {
807c532df63SYohann     CeedInt P1d, Q1d;
808c532df63SYohann     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
809c532df63SYohann     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
810c532df63SYohann     ierr = CeedCudaInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d,
811c532df63SYohann                                   Q1d, &data->c_B, &data->c_G);
812c532df63SYohann     CeedChk(ierr);
813cb0b5415Sjeremylt     void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
814ccf0fe6fSjeremylt                         &data->c_G, &d_u, &d_v
815ccf0fe6fSjeremylt                        };
8164d537eeaSYohann     if (dim == 1) {
817d94769d2SYohann Dudouit       CeedInt elemsPerBlock = 32;
8184d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
8194d537eeaSYohann                                              ? 1 : 0 );
820d94769d2SYohann Dudouit       CeedInt sharedMem = elemsPerBlock*Q1d*sizeof(CeedScalar);
821*ab213215SJeremy L Thompson       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, 1,
822*ab213215SJeremy L Thompson                                         elemsPerBlock, sharedMem, gradargs);
823c532df63SYohann       CeedChk(ierr);
824074be161SYohann Dudouit     } else if (dim == 2) {
8254247ecf3SYohann Dudouit       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
8264247ecf3SYohann Dudouit       CeedInt elemsPerBlock = Q1d < 7 ? optElems[Q1d]/ncomp : 1;
8274d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
8284d537eeaSYohann                                              ? 1 : 0 );
8294247ecf3SYohann Dudouit       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
8304d537eeaSYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, Q1d,
8314d537eeaSYohann                                         ncomp*elemsPerBlock, sharedMem,
832*ab213215SJeremy L Thompson                                         gradargs); CeedChk(ierr);
833074be161SYohann Dudouit     } else if (dim == 3) {
8343f63d318SYohann Dudouit       CeedInt elemsPerBlock = 1;
8354d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
8364d537eeaSYohann                                              ? 1 : 0 );
837698ebc35SYohann Dudouit       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
8384d537eeaSYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, Q1d,
8394d537eeaSYohann                                         ncomp*elemsPerBlock, sharedMem,
840*ab213215SJeremy L Thompson                                         gradargs); CeedChk(ierr);
841074be161SYohann Dudouit     }
842*ab213215SJeremy L Thompson   } break;
843*ab213215SJeremy L Thompson   case CEED_EVAL_WEIGHT: {
844074be161SYohann Dudouit     CeedInt Q1d;
845074be161SYohann Dudouit     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
846c532df63SYohann     void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
847074be161SYohann Dudouit     if (dim == 1) {
848074be161SYohann Dudouit       const CeedInt elemsPerBlock = 32/Q1d;
8494d537eeaSYohann       const CeedInt gridsize = nelem/elemsPerBlock + ( (
8504d537eeaSYohann                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
8517f823360Sjeremylt       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d,
8527f823360Sjeremylt                                   elemsPerBlock, 1, weightargs);
8531226057fSYohann Dudouit       CeedChk(ierr);
854074be161SYohann Dudouit     } else if (dim == 2) {
855717ff8a3SYohann Dudouit       const CeedInt optElems = 32/(Q1d*Q1d);
856717ff8a3SYohann Dudouit       const CeedInt elemsPerBlock = optElems>0?optElems:1;
8574d537eeaSYohann       const CeedInt gridsize = nelem/elemsPerBlock + ( (
8584d537eeaSYohann                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
8594d537eeaSYohann       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d,
8604d537eeaSYohann                                   elemsPerBlock, weightargs);
8611226057fSYohann Dudouit       CeedChk(ierr);
862074be161SYohann Dudouit     } else if (dim == 3) {
863074be161SYohann Dudouit       const CeedInt gridsize = nelem;
8644d537eeaSYohann       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d, Q1d,
8654d537eeaSYohann                                   weightargs);
8661226057fSYohann Dudouit       CeedChk(ierr);
867074be161SYohann Dudouit     }
868*ab213215SJeremy L Thompson   } break;
869*ab213215SJeremy L Thompson   // LCOV_EXCL_START
870*ab213215SJeremy L Thompson   // Evaluate the divergence to/from the quadrature points
871*ab213215SJeremy L Thompson   case CEED_EVAL_DIV:
872*ab213215SJeremy L Thompson     return CeedError(ceed, 1, "CEED_EVAL_DIV not supported");
873*ab213215SJeremy L Thompson   // Evaluate the curl to/from the quadrature points
874*ab213215SJeremy L Thompson   case CEED_EVAL_CURL:
875*ab213215SJeremy L Thompson     return CeedError(ceed, 1, "CEED_EVAL_CURL not supported");
876*ab213215SJeremy L Thompson   // Take no action, BasisApply should not have been called
877*ab213215SJeremy L Thompson   case CEED_EVAL_NONE:
878*ab213215SJeremy L Thompson     return CeedError(ceed, 1,
879*ab213215SJeremy L Thompson                      "CEED_EVAL_NONE does not make sense in this context");
880*ab213215SJeremy L Thompson     // LCOV_EXCL_STOP
881c532df63SYohann   }
882c532df63SYohann 
883*ab213215SJeremy L Thompson   // Restore vectors
884c532df63SYohann   if (emode != CEED_EVAL_WEIGHT) {
885c532df63SYohann     ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChk(ierr);
886c532df63SYohann   }
887c532df63SYohann   ierr = CeedVectorRestoreArray(v, &d_v); CeedChk(ierr);
888c532df63SYohann   return 0;
889c532df63SYohann }
890c532df63SYohann 
891*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
892*ab213215SJeremy L Thompson // Destroy basis
893*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
894c532df63SYohann static int CeedBasisDestroy_Cuda_shared(CeedBasis basis) {
895c532df63SYohann   int ierr;
896c532df63SYohann   Ceed ceed;
897c532df63SYohann   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
898c532df63SYohann 
899c532df63SYohann   CeedBasis_Cuda_shared *data;
900c532df63SYohann   ierr = CeedBasisGetData(basis, (void *) &data); CeedChk(ierr);
901c532df63SYohann 
902c532df63SYohann   CeedChk_Cu(ceed, cuModuleUnload(data->module));
903c532df63SYohann 
904c532df63SYohann   ierr = cudaFree(data->d_qweight1d); CeedChk_Cu(ceed, ierr);
905c532df63SYohann   ierr = cudaFree(data->d_interp1d); CeedChk_Cu(ceed, ierr);
906c532df63SYohann   ierr = cudaFree(data->d_grad1d); CeedChk_Cu(ceed, ierr);
907c532df63SYohann 
908c532df63SYohann   ierr = CeedFree(&data); CeedChk(ierr);
909c532df63SYohann 
910c532df63SYohann   return 0;
911c532df63SYohann }
912c532df63SYohann 
913*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
914*ab213215SJeremy L Thompson // Create tensor basis
915*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
916c532df63SYohann int CeedBasisCreateTensorH1_Cuda_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
917c532df63SYohann                                         const CeedScalar *interp1d,
918c532df63SYohann                                         const CeedScalar *grad1d,
919c532df63SYohann                                         const CeedScalar *qref1d,
920c532df63SYohann                                         const CeedScalar *qweight1d,
921c532df63SYohann                                         CeedBasis basis) {
922c532df63SYohann   int ierr;
923c532df63SYohann   Ceed ceed;
924c532df63SYohann   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
9254d537eeaSYohann   if (Q1d<P1d) {
9261226057fSYohann Dudouit     return CeedError(ceed, 1, "Backend does not implement underintegrated basis.");
9271226057fSYohann Dudouit   }
928c532df63SYohann   CeedBasis_Cuda_shared *data;
929c532df63SYohann   ierr = CeedCalloc(1, &data); CeedChk(ierr);
930c532df63SYohann 
931*ab213215SJeremy L Thompson   // Copy basis data to GPU
932c532df63SYohann   const CeedInt qBytes = Q1d * sizeof(CeedScalar);
933c532df63SYohann   ierr = cudaMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Cu(ceed, ierr);
934c532df63SYohann   ierr = cudaMemcpy(data->d_qweight1d, qweight1d, qBytes,
935c532df63SYohann                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
936c532df63SYohann 
937c532df63SYohann   const CeedInt iBytes = qBytes * P1d;
938c532df63SYohann   ierr = cudaMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Cu(ceed, ierr);
939c532df63SYohann   ierr = cudaMemcpy(data->d_interp1d, interp1d, iBytes,
940c532df63SYohann                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
941c532df63SYohann 
942c532df63SYohann   ierr = cudaMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Cu(ceed, ierr);
943c532df63SYohann   ierr = cudaMemcpy(data->d_grad1d, grad1d, iBytes,
944c532df63SYohann                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
945c532df63SYohann 
946*ab213215SJeremy L Thompson   // Compute collocated gradient and copy to GPU
947ac421f39SYohann   data->d_collograd1d = NULL;
948ac421f39SYohann   if (dim == 3 && Q1d >= P1d) {
949ac421f39SYohann     CeedScalar *collograd1d;
950ac421f39SYohann     ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChk(ierr);
951ac421f39SYohann     ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChk(ierr);
952ac421f39SYohann     ierr = cudaMalloc((void **)&data->d_collograd1d, qBytes * Q1d);
953ac421f39SYohann     CeedChk_Cu(ceed, ierr);
954ac421f39SYohann     ierr = cudaMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d,
955ac421f39SYohann                       cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
956ac421f39SYohann   }
957ac421f39SYohann 
958*ab213215SJeremy L Thompson   // Compile basis kernels
959c532df63SYohann   CeedInt ncomp;
960c532df63SYohann   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
9614a6d4bbdSYohann Dudouit   ierr = CeedCompileCuda(ceed, kernelsShared, &data->module, 7,
962c532df63SYohann                          "Q1D", Q1d,
963c532df63SYohann                          "P1D", P1d,
964c532df63SYohann                          "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
965c532df63SYohann                              Q1d : P1d, dim),
966c532df63SYohann                          "BASIS_DIM", dim,
967c532df63SYohann                          "BASIS_NCOMP", ncomp,
968c532df63SYohann                          "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
969c532df63SYohann                          "BASIS_NQPT", CeedIntPow(Q1d, dim)
970c532df63SYohann                         ); CeedChk(ierr);
9714a6d4bbdSYohann Dudouit   ierr = CeedGetKernelCuda(ceed, data->module, "interp", &data->interp);
972c532df63SYohann   CeedChk(ierr);
9734a6d4bbdSYohann Dudouit   ierr = CeedGetKernelCuda(ceed, data->module, "grad", &data->grad);
974c532df63SYohann   CeedChk(ierr);
9754a6d4bbdSYohann Dudouit   ierr = CeedGetKernelCuda(ceed, data->module, "weight", &data->weight);
976c532df63SYohann   CeedChk(ierr);
977c532df63SYohann 
978*ab213215SJeremy L Thompson   ierr = CeedBasisSetData(basis, (void *)&data); CeedChk(ierr);
979*ab213215SJeremy L Thompson 
980*ab213215SJeremy L Thompson   // Register backend functions
981c532df63SYohann   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
982c532df63SYohann                                 CeedBasisApplyTensor_Cuda_shared);
983c532df63SYohann   CeedChk(ierr);
984c532df63SYohann   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
985*ab213215SJeremy L Thompson                                 CeedBasisDestroy_Cuda_shared); CeedChk(ierr);
986c532df63SYohann   return 0;
987c532df63SYohann }
988*ab213215SJeremy L Thompson //------------------------------------------------------------------------------
989