xref: /libCEED/rust/libceed-sys/c-src/backends/cuda-shared/ceed-cuda-shared-basis.c (revision d99fa3c5cd91a1690aedf0679cbf290d44fec74c)
1c532df63SYohann // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2c532df63SYohann // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3c532df63SYohann // All Rights reserved. See files LICENSE and NOTICE for details.
4c532df63SYohann //
5c532df63SYohann // This file is part of CEED, a collection of benchmarks, miniapps, software
6c532df63SYohann // libraries and APIs for efficient high-order finite element and spectral
7c532df63SYohann // element discretizations for exascale applications. For more information and
8c532df63SYohann // source code availability see http://github.com/ceed.
9c532df63SYohann //
10c532df63SYohann // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11c532df63SYohann // a collaborative effort of two U.S. Department of Energy organizations (Office
12c532df63SYohann // of Science and the National Nuclear Security Administration) responsible for
13c532df63SYohann // the planning and preparation of a capable exascale ecosystem, including
14c532df63SYohann // software, applications, hardware, advanced system engineering and early
15c532df63SYohann // testbed platforms, in support of the nation's exascale computing imperative.
16c532df63SYohann 
17c532df63SYohann #include "ceed-cuda-shared.h"
18c532df63SYohann 
19ab213215SJeremy L Thompson //------------------------------------------------------------------------------
20ab213215SJeremy L Thompson // Shared mem kernels
21ab213215SJeremy L Thompson //------------------------------------------------------------------------------
22cb0b5415Sjeremylt // *INDENT-OFF*
23c532df63SYohann static const char *kernelsShared = QUOTE(
24c532df63SYohann 
25ab213215SJeremy L Thompson //------------------------------------------------------------------------------
26ab213215SJeremy L Thompson // Sum input into output
27ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Accumulate the Q1D entries of r_U into the matching entries of r_V.
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  for (int q = 0; q < Q1D; ++q) {
    r_V[q] = r_V[q] + r_U[q];
  }
}
32c532df63SYohann 
33ab213215SJeremy L Thompson //------------------------------------------------------------------------------
34ab213215SJeremy L Thompson // 1D
35ab213215SJeremy L Thompson //------------------------------------------------------------------------------
36c532df63SYohann 
37ab213215SJeremy L Thompson //------------------------------------------------------------------------------
38ab213215SJeremy L Thompson // Read DoFs
39ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Fill the slice row selected by tidz with the P1D nodal values of component
// comp of element elem, then pad the row with zeros up to Q1D entries.
// Every calling thread writes the whole row; tidx and tidy are not used.
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int tidz, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *slice) {
  const int offset = elem*P1D + comp*P1D*nelem;
  CeedScalar *row = slice + tidz*Q1D;
  int n = 0;
  // Nodal values
  for (; n < P1D; ++n)
    row[n] = d_U[n + offset];
  // Zero padding from P1D up to Q1D
  for (; n < Q1D; ++n)
    row[n] = 0.0;
}
49c532df63SYohann 
50ab213215SJeremy L Thompson //------------------------------------------------------------------------------
51ab213215SJeremy L Thompson // Write DoFs
52ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Store r_V as nodal value tidx of component comp of element elem.
// Threads with tidx >= P1D have no node to write and return at once.
// tidy is not used.
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D)
    return;
  const int node = tidx + elem*P1D + comp*P1D*nelem;
  d_V[node] = r_V;
}
60c532df63SYohann 
61ab213215SJeremy L Thompson //------------------------------------------------------------------------------
62ab213215SJeremy L Thompson // Read quadrature point data
63ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Copy the Q1D quadrature point values of component comp and derivative
// direction dim of element elem into the slice row selected by tidz.
// Every calling thread writes the whole row; tidx and tidy are not used.
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int tidz, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *slice) {
  const int offset = elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  CeedScalar *row = slice + tidz*Q1D;
  for (int q = 0; q < Q1D; ++q)
    row[q] = d_U[q + offset];
}
72c532df63SYohann 
73ab213215SJeremy L Thompson //------------------------------------------------------------------------------
74ab213215SJeremy L Thompson // Write quadrature point data
75ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Store r_V as quadrature point tidx of component comp and derivative
// direction dim of element elem. tidy is not used.
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  const int quad = tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  d_V[quad] = r_V;
}
82c532df63SYohann 
83ab213215SJeremy L Thompson //------------------------------------------------------------------------------
84ab213215SJeremy L Thompson // 1D tensor contraction
85ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// 1D contraction: V = sum over i < P1D of B[i + tidx*P1D] * slice[i + tidz*Q1D].
// The input is taken from the slice row of tidz; U and tidy are not used.
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  const CeedScalar *row = slice + tidz*Q1D;
  const CeedScalar *coeff = B + tidx*P1D;
  CeedScalar sum = 0.0;
  for (int i = 0; i < P1D; ++i) {
    sum += coeff[i] * row[i]; // Contract x direction
  }
  V = sum;
}
94c532df63SYohann 
95ab213215SJeremy L Thompson //------------------------------------------------------------------------------
96ab213215SJeremy L Thompson // 1D transpose tensor contraction
97ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Transposed 1D contraction:
//   V = sum over i < Q1D of B[tidx + i*P1D] * slice[i + tidz*Q1D].
// Only threads with tidx < P1D select a valid column of B. The remaining
// threads get V = 0 and skip the loop; without this guard they would index B
// beyond P1D*Q1D entries whenever Q1D > P1D. The same guard is used by
// ContractTransposeX2d. NOTE(review): assumes B holds P1D*Q1D entries.
// U and tidy are not used.
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  V = 0.0;
  if (tidx < P1D)
    for (int i = 0; i < Q1D; ++i)
      V += B[tidx + i*P1D] * slice[i + tidz*Q1D]; // Contract x direction
}
105c532df63SYohann 
106ab213215SJeremy L Thompson //------------------------------------------------------------------------------
107ab213215SJeremy L Thompson // 1D interpolate to quadrature points
108ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// 1D interpolation of all BASIS_NCOMP components, nodes to quadrature points
// when transpose is zero and quadrature points to nodes otherwise.
// threadIdx.z selects the element within a block and threadIdx.x the output
// point; each pass of the outer loop handles blockDim.z elements per block.
//
// The slice row of an element is rewritten for every component, and all
// threads of that element read it. The outer loop therefore has the same trip
// count for every thread of the block, so that the two barriers below are
// reached by all of them: the first orders the row fill before the reads, the
// second orders the reads before the next fill. Threads whose element index is
// past nelem only take part in the barriers.
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V = 0.0;
  // The 1D contractions ignore their U argument; r_t only fills that slot
  const CeedScalar r_t = 0.0;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const CeedInt elemsPerBlock = blockDim.z;

  for (CeedInt blockBase = blockIdx.x*elemsPerBlock; blockBase < nelem;
       blockBase += gridDim.x*elemsPerBlock) {
    const CeedInt elem = blockBase + tidz;
    const bool active = elem < nelem;
    for (int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (active) {
        if (!transpose)
          readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        else
          readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
      }
      __syncthreads(); // Slice rows are complete before any thread reads them
      if (active) {
        if (!transpose) {
          ContractX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
          writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
        } else {
          ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
          writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
        }
      }
      __syncthreads(); // All reads are done before the rows are overwritten
    }
  }
}
137c532df63SYohann 
138ab213215SJeremy L Thompson //------------------------------------------------------------------------------
139ab213215SJeremy L Thompson // 1D derivatives at quadrature points
140ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// 1D derivative of all BASIS_NCOMP components using c_G, nodes to quadrature
// points when transpose is zero and quadrature points to nodes otherwise.
// c_B is not used in 1D. threadIdx.z selects the element within a block and
// threadIdx.x the output point.
//
// The slice row of an element is rewritten for every component, and all
// threads of that element read it. The outer loop therefore has the same trip
// count for every thread of the block, so that the two barriers below are
// reached by all of them: the first orders the row fill before the reads, the
// second orders the reads before the next fill. Threads whose element index is
// past nelem only take part in the barriers.
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // The 1D contractions ignore their U argument; r_U only fills that slot
  const CeedScalar r_U = 0.0;
  CeedScalar r_V = 0.0;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int dim = 0; // Single derivative direction in 1D
  const CeedInt elemsPerBlock = blockDim.z;

  for (CeedInt blockBase = blockIdx.x*elemsPerBlock; blockBase < nelem;
       blockBase += gridDim.x*elemsPerBlock) {
    const CeedInt elem = blockBase + tidz;
    const bool active = elem < nelem;
    for (int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (active) {
        if (!transpose)
          readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        else
          readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice);
      }
      __syncthreads(); // Slice rows are complete before any thread reads them
      if (active) {
        if (!transpose) {
          ContractX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
          writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
        } else {
          ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
          writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
        }
      }
      __syncthreads(); // All reads are done before the rows are overwritten
    }
  }
}
171c532df63SYohann 
172ab213215SJeremy L Thompson //------------------------------------------------------------------------------
173ab213215SJeremy L Thompson // 1D Quadrature weights
174ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Write the 1D quadrature weight of point threadIdx.x into w for every element
// visited by this thread. Elements are selected by threadIdx.y within a block
// and advance by gridDim.x*blockDim.y per pass.
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int q = threadIdx.x;
  const CeedScalar w_q = qweight1d[q];
  const unsigned int step = gridDim.x*blockDim.y;
  CeedInt e = blockIdx.x*blockDim.y + threadIdx.y;
  while (e < nelem) {
    w[e*Q1D + q] = w_q;
    e += step;
  }
}
185ab213215SJeremy L Thompson 
186ab213215SJeremy L Thompson //------------------------------------------------------------------------------
187ab213215SJeremy L Thompson // 2D
188ab213215SJeremy L Thompson //------------------------------------------------------------------------------
189ab213215SJeremy L Thompson 
190ab213215SJeremy L Thompson //------------------------------------------------------------------------------
191ab213215SJeremy L Thompson // Read DoFs
192ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Load the nodal value at position tidx, tidy of component comp of element
// elem into U. Threads outside the P1D x P1D node grid receive zero.
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar &U) {
  if (tidx < P1D && tidy < P1D) {
    const int node = tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem;
    U = d_U[node];
  } else {
    U = 0.0;
  }
}
200c532df63SYohann 
201ab213215SJeremy L Thompson //------------------------------------------------------------------------------
202ab213215SJeremy L Thompson // Write DoFs
203ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Store r_V as the nodal value at position tidx, tidy of component comp of
// element elem. Threads outside the P1D x P1D node grid write nothing.
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D)
    return;
  const int node = tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem;
  d_V[node] = r_V;
}
211c532df63SYohann 
212ab213215SJeremy L Thompson //------------------------------------------------------------------------------
213ab213215SJeremy L Thompson // Read quadrature point data
214ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Load the quadrature point value at position tidx, tidy of component comp and
// derivative direction dim of element elem into U.
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar &U ) {
  const int quad = tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                   dim*BASIS_NCOMP*nelem*Q1D*Q1D;
  U = d_U[quad];
}
222c532df63SYohann 
223ab213215SJeremy L Thompson //------------------------------------------------------------------------------
224ab213215SJeremy L Thompson // Write quadrature point data
225ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Store r_V as the quadrature point value at position tidx, tidy of component
// comp and derivative direction dim of element elem.
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  const int quad = tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                   dim*BASIS_NCOMP*nelem*Q1D*Q1D;
  d_V[quad] = r_V;
}
233c532df63SYohann 
234ab213215SJeremy L Thompson //------------------------------------------------------------------------------
235ab213215SJeremy L Thompson // 2D tensor contraction x
236ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// 2D contraction along x. Each thread publishes U in the Q1D x Q1D slice plane
// of tidz, then computes
//   V = sum over i < P1D of B[i + tidx*P1D] * plane[i + tidy*Q1D].
// Contains two block-wide barriers, so every thread of the block must call it.
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  CeedScalar *plane = slice + tidz*Q1D*Q1D;
  plane[tidx + tidy*Q1D] = U;
  __syncthreads(); // Plane is fully written before it is read
  const CeedScalar *row = plane + tidy*Q1D;
  const CeedScalar *coeff = B + tidx*P1D;
  CeedScalar sum = 0.0;
  for (int i = 0; i < P1D; ++i) {
    sum += coeff[i] * row[i]; // Contract x direction
  }
  V = sum;
  __syncthreads(); // Plane is fully read before the next overwrite
}
248c532df63SYohann 
249ab213215SJeremy L Thompson //------------------------------------------------------------------------------
250ab213215SJeremy L Thompson // 2D tensor contraction y
251ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// 2D contraction along y. Each thread publishes U in the Q1D x Q1D slice plane
// of tidz, then computes
//   V = sum over i < P1D of B[i + tidy*P1D] * plane[tidx + i*Q1D].
// Contains two block-wide barriers, so every thread of the block must call it.
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  CeedScalar *plane = slice + tidz*Q1D*Q1D;
  plane[tidx + tidy*Q1D] = U;
  __syncthreads(); // Plane is fully written before it is read
  const CeedScalar *column = plane + tidx;
  const CeedScalar *coeff = B + tidy*P1D;
  CeedScalar sum = 0.0;
  for (int i = 0; i < P1D; ++i) {
    sum += coeff[i] * column[i*Q1D]; // Contract y direction
  }
  V = sum;
  __syncthreads(); // Plane is fully read before the next overwrite
}
263c532df63SYohann 
264ab213215SJeremy L Thompson //------------------------------------------------------------------------------
265ab213215SJeremy L Thompson // 2D transpose tensor contraction y
266ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Transposed 2D contraction along y. Each thread publishes U in the Q1D x Q1D
// slice plane of tidz, then threads with tidy < P1D compute
//   V = sum over i < Q1D of B[tidy + i*P1D] * plane[tidx + i*Q1D];
// all other threads get V = 0.
// Contains two block-wide barriers, so every thread of the block must call it.
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  CeedScalar *plane = slice + tidz*Q1D*Q1D;
  plane[tidx + tidy*Q1D] = U;
  __syncthreads(); // Plane is fully written before it is read
  CeedScalar sum = 0.0;
  if (tidy < P1D) {
    const CeedScalar *column = plane + tidx;
    for (int i = 0; i < Q1D; ++i) {
      sum += B[tidy + i*P1D] * column[i*Q1D]; // Contract y direction
    }
  }
  V = sum;
  __syncthreads(); // Plane is fully read before the next overwrite
}
278c532df63SYohann 
279ab213215SJeremy L Thompson //------------------------------------------------------------------------------
280ab213215SJeremy L Thompson // 2D transpose tensor contraction x
281ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Transposed 2D contraction along x. Each thread publishes U in the Q1D x Q1D
// slice plane of tidz, then threads with tidx < P1D compute
//   V = sum over i < Q1D of B[tidx + i*P1D] * plane[i + tidy*Q1D];
// all other threads get V = 0.
// Contains two block-wide barriers, so every thread of the block must call it.
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  CeedScalar *plane = slice + tidz*Q1D*Q1D;
  plane[tidx + tidy*Q1D] = U;
  __syncthreads(); // Plane is fully written before it is read
  CeedScalar sum = 0.0;
  if (tidx < P1D) {
    const CeedScalar *row = plane + tidy*Q1D;
    for (int i = 0; i < Q1D; ++i) {
      sum += B[tidx + i*P1D] * row[i]; // Contract x direction
    }
  }
  V = sum;
  __syncthreads(); // Plane is fully read before the next overwrite
}
293c532df63SYohann 
294ab213215SJeremy L Thompson //------------------------------------------------------------------------------
295ab213215SJeremy L Thompson // 2D interpolate to quadrature points
296ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// 2D interpolation, nodes to quadrature points when transpose is zero and
// quadrature points to nodes otherwise. threadIdx.x and threadIdx.y select the
// point, threadIdx.z selects both the element within the block
// (tidz / BASIS_NCOMP) and the component (tidz % BASIS_NCOMP).
//
// The contraction helpers contain block-wide barriers, so the element loop is
// written with the same trip count for every thread of the block. Threads
// whose element index is past nelem still run the contractions on a zero
// input, but skip all global memory reads and writes.
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt blockBase = blockIdx.x*elemsPerBlock; blockBase < nelem;
       blockBase += gridDim.x*elemsPerBlock) {
    const CeedInt elem = blockBase + blockElem;
    const bool active = elem < nelem;
    r_V = 0.0;
    r_t = 0.0;
    if (!transpose) {
      if (active)
        readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      if (active)
        writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
    } else {
      if (active)
        readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      if (active)
        writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
330c532df63SYohann 
331ab213215SJeremy L Thompson //------------------------------------------------------------------------------
332ab213215SJeremy L Thompson // 2D derivatives at quadrature points
333ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// 2D gradient. When transpose is zero, nodal values are mapped to the two
// derivative directions at quadrature points (c_G along one axis, c_B along
// the other). Otherwise both directions are read from quadrature points,
// contracted with the transposed operators and summed into nodal values.
// threadIdx.x and threadIdx.y select the point, threadIdx.z selects both the
// element within the block (tidz / BASIS_NCOMP) and the component
// (tidz % BASIS_NCOMP).
//
// The contraction helpers contain block-wide barriers, so the element loop is
// written with the same trip count for every thread of the block. Threads
// whose element index is past nelem still run the contractions, but skip all
// global memory reads and writes.
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V, CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt blockBase = blockIdx.x*elemsPerBlock; blockBase < nelem;
       blockBase += gridDim.x*elemsPerBlock) {
    const CeedInt elem = blockBase + blockElem;
    const bool active = elem < nelem;
    // Inactive threads feed zeros into the contractions
    r_U = 0.0;
    if (!transpose) {
      if (active)
        readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // Derivative along x
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      if (active)
        writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      // Derivative along y
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      if (active)
        writeQuads2d(elem, tidx, tidy, comp, 1, nelem, r_V, d_V);
    } else {
      // Contribution of the x derivative
      if (active)
        readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      // Contribution of the y derivative, accumulated into r_V
      if (active)
        readQuads2d(elem, tidx, tidy, comp, 1, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      r_V += r_U;
      if (active)
        writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
376c532df63SYohann 
377ab213215SJeremy L Thompson //------------------------------------------------------------------------------
378ab213215SJeremy L Thompson // 2D quadrature weights
379ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Write the 2D quadrature weight qweight1d[threadIdx.x]*qweight1d[threadIdx.y]
// into w for every element visited by this thread. Elements are selected by
// threadIdx.z within a block and advance by gridDim.x*blockDim.z per pass.
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const CeedScalar w_q = qweight1d[qx]*qweight1d[qy];
  const unsigned int step = gridDim.x*blockDim.z;
  CeedInt e = blockIdx.x*blockDim.z + threadIdx.z;
  while (e < nelem) {
    w[e*Q1D*Q1D + qx + qy*Q1D] = w_q;
    e += step;
  }
}
391ab213215SJeremy L Thompson 
392ab213215SJeremy L Thompson //------------------------------------------------------------------------------
393ab213215SJeremy L Thompson // 3D
394ab213215SJeremy L Thompson //------------------------------------------------------------------------------
395ab213215SJeremy L Thompson 
396ab213215SJeremy L Thompson //------------------------------------------------------------------------------
397ab213215SJeremy L Thompson // Read DoFs
398ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Load the column of P1D nodal values at position tidx, tidy of component comp
// of element elem into r_U, one entry per z layer, and zero the entries from
// P1D up to Q1D. Threads outside the P1D x P1D node grid receive only zeros.
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *r_U) {
  const bool inside = tidx < P1D && tidy < P1D;
  int k = 0;
  if (inside) {
    for (; k < P1D; ++k)
      r_U[k] = d_U[tidx + tidy*P1D + k*P1D*P1D + elem*P1D*P1D*P1D +
                   comp*P1D*P1D*P1D*nelem];
  } else {
    for (; k < P1D; ++k)
      r_U[k] = 0.0;
  }
  // Zero padding from P1D up to Q1D
  for (; k < Q1D; ++k)
    r_U[k] = 0.0;
}
410c532df63SYohann 
411ab213215SJeremy L Thompson //------------------------------------------------------------------------------
41249fd234cSJeremy L Thompson // Write DoFs
41349fd234cSJeremy L Thompson //------------------------------------------------------------------------------
// Store the first P1D entries of r_V as the column of nodal values at position
// tidx, tidy of component comp of element elem, one entry per z layer.
// Threads outside the P1D x P1D node grid write nothing.
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D)
    return;
  for (int k = 0; k < P1D; ++k) {
    const int node = tidx + tidy*P1D + k*P1D*P1D + elem*P1D*P1D*P1D +
                     comp*P1D*P1D*P1D*nelem;
    d_V[node] = r_V[k];
  }
}
42449fd234cSJeremy L Thompson 
42549fd234cSJeremy L Thompson //------------------------------------------------------------------------------
426ab213215SJeremy L Thompson // Read quadrature point data
427ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Load the column of Q1D quadrature point values at position tidx, tidy of
// component comp and derivative direction dim of element elem into r_U, one
// entry per z layer.
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *r_U) {
  for (int k = 0; k < Q1D; ++k) {
    const int quad = tidx + tidy*Q1D + k*Q1D*Q1D + elem*Q1D*Q1D*Q1D +
                     comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
    r_U[k] = d_U[quad];
  }
}
436c532df63SYohann 
437ab213215SJeremy L Thompson //------------------------------------------------------------------------------
438ab213215SJeremy L Thompson // Write quadrature point data
439ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Store the Q1D entries of r_V as the column of quadrature point values at
// position tidx, tidy of component comp and derivative direction dim of
// element elem, one entry per z layer.
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar *r_V, CeedScalar *d_V) {
  for (int k = 0; k < Q1D; ++k) {
    const int quad = tidx + tidy*Q1D + k*Q1D*Q1D + elem*Q1D*Q1D*Q1D +
                     comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
    d_V[quad] = r_V[k];
  }
}
448c532df63SYohann 
449ab213215SJeremy L Thompson //------------------------------------------------------------------------------
450ab213215SJeremy L Thompson // 3D tensor contract x
451ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// 3D contraction along x, applied to each of the first P1D z layers held in U.
// For layer k every thread publishes U[k] in the Q1D x Q1D slice plane of
// tidz, then computes
//   V[k] = sum over i < P1D of B[i + tidx*P1D] * plane[i + tidy*Q1D].
// Entries of V from P1D onwards are left untouched.
// Contains block-wide barriers, so every thread of the block must call it.
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B,
                                   CeedScalar *V) {
  CeedScalar *plane = slice + tidz*Q1D*Q1D;
  const CeedScalar *row = plane + tidy*Q1D;
  const CeedScalar *coeff = B + tidx*P1D;
  for (int k = 0; k < P1D; ++k) {
    plane[tidx + tidy*Q1D] = U[k];
    __syncthreads(); // Plane is fully written before it is read
    CeedScalar sum = 0.0;
    for (int i = 0; i < P1D; ++i) {
      sum += coeff[i] * row[i]; // Contract x direction
    }
    V[k] = sum;
    __syncthreads(); // Plane is fully read before the next layer overwrites it
  }
}
465c532df63SYohann 
//------------------------------------------------------------------------------
// 3D tensor contract y
//------------------------------------------------------------------------------
// For each of the first P1D entries of U, shares the entry through slice and
// contracts it along y with B, indexed as B[node + tidy*P1D].
// Only V[0..P1D) is written. Contains block-wide barriers, so every thread of
// the block must call this function.
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B,
                                   CeedScalar *V) {
  // Column of slice shared by all threads with the same (tidx, tidz)
  CeedScalar *column = slice + tidx + tidz*Q1D*Q1D;
  const CeedScalar *basisRow = B + tidy*P1D;
  for (int k = 0; k < P1D; ++k) {
    column[tidy*Q1D] = U[k];
    __syncthreads(); // Column must be complete before any thread reads it
    CeedScalar sum = 0.0;
    for (int i = 0; i < P1D; ++i)
      sum += basisRow[i] * column[i*Q1D]; // Contract y direction
    V[k] = sum;
    __syncthreads(); // Column must be fully read before the next overwrite
  }
}
482c532df63SYohann 
//------------------------------------------------------------------------------
// 3D tensor contract z
//------------------------------------------------------------------------------
// Contracts the first P1D entries of U with B, indexed as B[node + quad*P1D],
// and writes all Q1D entries of V. Works on thread-local data only; slice and
// the thread indices are accepted for a uniform signature but are not used.
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B,
                                   CeedScalar *V) {
  for (int q = 0; q < Q1D; ++q) {
    const CeedScalar *basisRow = B + q*P1D;
    V[q] = 0.0;
    for (int p = 0; p < P1D; ++p)
      V[q] += basisRow[p] * U[p]; // Contract z direction
  }
}
496c532df63SYohann 
//------------------------------------------------------------------------------
// 3D transpose tensor contract z
//------------------------------------------------------------------------------
// Contracts the Q1D entries of U with the transpose of B, indexed as
// B[node + quad*P1D]. Writes all Q1D entries of V; entries with index >= P1D
// are set to zero. Works on thread-local data only; slice and the thread
// indices are not used.
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < Q1D; ++k) {
    V[k] = 0.0;
    if (k >= P1D)
      continue; // No node with this index, entry stays zero
    for (int q = 0; q < Q1D; ++q)
      V[k] += B[k + q*P1D] * U[q]; // Contract z direction
  }
}
510c532df63SYohann 
//------------------------------------------------------------------------------
// 3D transpose tensor contract y
//------------------------------------------------------------------------------
// For each of the first P1D entries of U, shares the entry through slice and
// contracts it along y with the transpose of B, indexed as B[node + quad*P1D].
// Threads with tidy >= P1D still take part in the exchange and the barriers
// but produce zero. Only V[0..P1D) is written.
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  // Column of slice shared by all threads with the same (tidx, tidz)
  CeedScalar *column = slice + tidx + tidz*Q1D*Q1D;
  const bool ownsNode = tidy < P1D;
  for (int k = 0; k < P1D; ++k) {
    column[tidy*Q1D] = U[k];
    __syncthreads(); // Column must be complete before any thread reads it
    CeedScalar sum = 0.0;
    if (ownsNode) {
      for (int q = 0; q < Q1D; ++q)
        sum += B[tidy + q*P1D] * column[q*Q1D]; // Contract y direction
    }
    V[k] = sum;
    __syncthreads(); // Column must be fully read before the next overwrite
  }
}
527c532df63SYohann 
//------------------------------------------------------------------------------
// 3D transpose tensor contract x
//------------------------------------------------------------------------------
// For each of the first P1D entries of U, shares the entry through slice and
// contracts it along x with the transpose of B, indexed as B[node + quad*P1D].
// Threads with tidx >= P1D still take part in the exchange and the barriers
// but produce zero. Only V[0..P1D) is written.
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  // Row of slice shared by all threads with the same (tidy, tidz)
  CeedScalar *row = slice + tidy*Q1D + tidz*Q1D*Q1D;
  const bool ownsNode = tidx < P1D;
  for (int k = 0; k < P1D; ++k) {
    row[tidx] = U[k];
    __syncthreads(); // Row must be complete before any thread reads it
    CeedScalar sum = 0.0;
    if (ownsNode) {
      for (int q = 0; q < Q1D; ++q)
        sum += B[tidx + q*P1D] * row[q]; // Contract x direction
    }
    V[k] = sum;
    __syncthreads(); // Row must be fully read before the next overwrite
  }
}
544c532df63SYohann 
//------------------------------------------------------------------------------
// 3D interpolate to quadrature points
//------------------------------------------------------------------------------
// Applies the 1D matrix c_B along x, y and z (or its transpose along z, y and
// x when transpose is nonzero) for every element assigned to this block.
// threadIdx.z selects both the element within the block and the component:
// element = threadIdx.z / BASIS_NCOMP, component = threadIdx.z % BASIS_NCOMP.
// The contraction helpers contain block-wide barriers, so all threads of a
// block have to run the same number of loop iterations.
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  // Two thread-local work arrays used alternately as input and output
  CeedScalar r_A[Q1D];
  CeedScalar r_B[Q1D];

  const int tx = threadIdx.x;
  const int ty = threadIdx.y;
  const int tz = threadIdx.z;
  const int comp = tz%BASIS_NCOMP;
  const int localElem = tz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int elemStride = gridDim.x*elemsPerBlock;

  CeedInt elem = blockIdx.x*elemsPerBlock + localElem;
  while (elem < nelem) {
    for (int i = 0; i < Q1D; ++i) {
      r_A[i] = 0.0;
      r_B[i] = 0.0;
    }
    if (transpose) {
      // Quadrature points to nodes
      readQuads3d(elem, tx, ty, comp, 0, nelem, d_U, r_A);
      ContractTransposeZ3d(slice, tx, ty, tz, r_A, c_B, r_B);
      ContractTransposeY3d(slice, tx, ty, tz, r_B, c_B, r_A);
      ContractTransposeX3d(slice, tx, ty, tz, r_A, c_B, r_B);
      writeDofs3d(elem, tx, ty, comp, nelem, r_B, d_V);
    } else {
      // Nodes to quadrature points
      readDofs3d(elem, tx, ty, comp, nelem, d_U, r_A);
      ContractX3d(slice, tx, ty, tz, r_A, c_B, r_B);
      ContractY3d(slice, tx, ty, tz, r_B, c_B, r_A);
      ContractZ3d(slice, tx, ty, tz, r_A, c_B, r_B);
      writeQuads3d(elem, tx, ty, comp, 0, nelem, r_B, d_V);
    }
    elem += elemStride;
  }
}
584c532df63SYohann 
//------------------------------------------------------------------------------
// 3D derivatives at quadrature points
//------------------------------------------------------------------------------
// Non-transpose: for each direction d in x, y, z the nodal values are
// contracted with c_G along direction d and with c_B along the other two
// directions, and the result is written as derivative direction d.
// Transpose: the three derivative directions are read in turn, contracted
// with the transposed matrices, and summed into one nodal result.
// threadIdx.z selects both the element within the block and the component.
// The contraction helpers contain block-wide barriers, so all threads of a
// block have to run the same number of loop iterations.
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // Use P1D for one of these
  CeedScalar r_U[Q1D];
  CeedScalar r_V[Q1D];
  CeedScalar r_t[Q1D];

  const int tx = threadIdx.x;
  const int ty = threadIdx.y;
  const int tz = threadIdx.z;
  const int comp = tz%BASIS_NCOMP;
  const int localElem = tz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int elemStride = gridDim.x*elemsPerBlock;

  CeedInt elem = blockIdx.x*elemsPerBlock + localElem;
  while (elem < nelem) {
    if (transpose) {
      for (int d = 0; d < 3; ++d) {
        // The gradient matrix is used only along direction d
        const CeedScalar *matZ = (d == 2) ? c_G : c_B;
        const CeedScalar *matY = (d == 1) ? c_G : c_B;
        readQuads3d(elem, tx, ty, comp, d, nelem, d_U, r_U);
        ContractTransposeZ3d(slice, tx, ty, tz, r_U, matZ, r_t);
        ContractTransposeY3d(slice, tx, ty, tz, r_t, matY, r_U);
        if (d == 0) {
          // First direction initializes the accumulated result
          ContractTransposeX3d(slice, tx, ty, tz, r_U, c_G, r_V);
        } else {
          ContractTransposeX3d(slice, tx, ty, tz, r_U, c_B, r_t);
          add(r_V, r_t);
        }
      }
      writeDofs3d(elem, tx, ty, comp, nelem, r_V, d_V);
    } else {
      readDofs3d(elem, tx, ty, comp, nelem, d_U, r_U);
      for (int d = 0; d < 3; ++d) {
        // The gradient matrix is used only along direction d
        const CeedScalar *matX = (d == 0) ? c_G : c_B;
        const CeedScalar *matY = (d == 1) ? c_G : c_B;
        const CeedScalar *matZ = (d == 2) ? c_G : c_B;
        ContractX3d(slice, tx, ty, tz, r_U, matX, r_V);
        ContractY3d(slice, tx, ty, tz, r_V, matY, r_t);
        ContractZ3d(slice, tx, ty, tz, r_t, matZ, r_V);
        writeQuads3d(elem, tx, ty, comp, d, nelem, r_V, d_V);
      }
    }
    elem += elemStride;
  }
}
647c532df63SYohann 
//------------------------------------------------------------------------------
// 3D quadrature weights
//------------------------------------------------------------------------------
// Each thread (threadIdx.x, threadIdx.y, threadIdx.z) owns one quadrature
// point. Its weight is the product of the three 1D weights, and it is stored
// for every element handled by this block (elements blockIdx.x,
// blockIdx.x + gridDim.x, ...).
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const int qz = threadIdx.z;
  // Position of this quadrature point inside one element
  const int pointOffset = qx + qy*Q1D + qz*Q1D*Q1D;
  const CeedScalar pointWeight = qweight1d[qx]*qweight1d[qy]*qweight1d[qz];
  for (int e = blockIdx.x; e < nelem; e += gridDim.x)
    w[e*Q1D*Q1D*Q1D + pointOffset] = pointWeight;
}
662ab213215SJeremy L Thompson 
663ab213215SJeremy L Thompson 
664ab213215SJeremy L Thompson //------------------------------------------------------------------------------
665ab213215SJeremy L Thompson // Basis kernels
666ab213215SJeremy L Thompson //------------------------------------------------------------------------------
667ab213215SJeremy L Thompson 
//------------------------------------------------------------------------------
// Interp kernel by dim
//------------------------------------------------------------------------------
// Kernel entry point for interpolation; forwards to the 1D, 2D or 3D device
// routine selected by the compile-time constant BASIS_DIM.
//   nelem     - number of elements to process
//   transpose - nonzero to apply the transpose operation
//   c_B       - 1D interpolation matrix
//   d_U, d_V  - input and output arrays
// The dynamic shared memory buffer is sized by the host at launch. It is
// declared as CeedScalar so that it matches the host-side size computation
// and the CeedScalar pointer expected by the device routines.
extern "C" __global__ void interp(const CeedInt nelem, const int transpose,
                                  const CeedScalar *c_B,
                                  const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  extern __shared__ CeedScalar slice[];
  if (BASIS_DIM == 1) {
    interp1d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM == 2) {
    interp2d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM == 3) {
    interp3d(nelem, transpose, c_B, d_U, d_V, slice);
  }
}
684c532df63SYohann 
//------------------------------------------------------------------------------
// Grad kernel by dim
//------------------------------------------------------------------------------
// Kernel entry point for gradients; forwards to the 1D, 2D or 3D device
// routine selected by the compile-time constant BASIS_DIM.
//   nelem     - number of elements to process
//   transpose - nonzero to apply the transpose operation
//   c_B, c_G  - 1D interpolation and gradient matrices
//   d_U, d_V  - input and output arrays
// The dynamic shared memory buffer is sized by the host at launch. It is
// declared as CeedScalar so that it matches the host-side size computation
// and the CeedScalar pointer expected by the device routines.
extern "C" __global__ void grad(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *c_G,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V) {
  extern __shared__ CeedScalar slice[];
  if (BASIS_DIM == 1) {
    grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM == 2) {
    grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM == 3) {
    grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  }
}
701c532df63SYohann 
//------------------------------------------------------------------------------
// Weight kernels by dim
//------------------------------------------------------------------------------
// Kernel entry point for quadrature weights; forwards to the 1D, 2D or 3D
// device routine selected by the compile-time constant BASIS_DIM. Any other
// value of BASIS_DIM makes the kernel a no-op.
extern "C" __global__ void weight(const CeedInt nelem,
                                  const CeedScalar *__restrict__ qweight1d,
                                  CeedScalar *__restrict__ v) {
  switch (BASIS_DIM) {
  case 1:
    weight1d(nelem, qweight1d, v);
    break;
  case 2:
    weight2d(nelem, qweight1d, v);
    break;
  case 3:
    weight3d(nelem, qweight1d, v);
    break;
  default:
    break;
  }
}
716c532df63SYohann 
717c532df63SYohann );
718cb0b5415Sjeremylt // *INDENT-ON*
719c532df63SYohann 
720ab213215SJeremy L Thompson //------------------------------------------------------------------------------
721ab213215SJeremy L Thompson // Device initalization
722ab213215SJeremy L Thompson //------------------------------------------------------------------------------
723c532df63SYohann int CeedCudaInitInterp(CeedScalar *d_B, CeedInt P1d, CeedInt Q1d,
724c532df63SYohann                        CeedScalar **c_B);
725c532df63SYohann int CeedCudaInitInterpGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P1d,
7267f823360Sjeremylt                            CeedInt Q1d, CeedScalar **c_B_ptr,
7277f823360Sjeremylt                            CeedScalar **c_G_ptr);
728c532df63SYohann 
729ab213215SJeremy L Thompson //------------------------------------------------------------------------------
730ab213215SJeremy L Thompson // Apply basis
731ab213215SJeremy L Thompson //------------------------------------------------------------------------------
732c532df63SYohann int CeedBasisApplyTensor_Cuda_shared(CeedBasis basis, const CeedInt nelem,
733c532df63SYohann                                      CeedTransposeMode tmode,
7347f823360Sjeremylt                                      CeedEvalMode emode, CeedVector u,
7357f823360Sjeremylt                                      CeedVector v) {
736c532df63SYohann   int ierr;
737c532df63SYohann   Ceed ceed;
738c532df63SYohann   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
739c532df63SYohann   Ceed_Cuda_shared *ceed_Cuda;
740c532df63SYohann   CeedGetData(ceed, (void *) &ceed_Cuda); CeedChk(ierr);
741c532df63SYohann   CeedBasis_Cuda_shared *data;
742c532df63SYohann   CeedBasisGetData(basis, (void *)&data); CeedChk(ierr);
743c532df63SYohann   const CeedInt transpose = tmode == CEED_TRANSPOSE;
7444247ecf3SYohann Dudouit   CeedInt dim, ncomp;
745074be161SYohann Dudouit   ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
7464247ecf3SYohann Dudouit   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
747c532df63SYohann 
748ab213215SJeremy L Thompson   // Read vectors
749c532df63SYohann   const CeedScalar *d_u;
750c532df63SYohann   CeedScalar *d_v;
751c532df63SYohann   if (emode != CEED_EVAL_WEIGHT) {
752c532df63SYohann     ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChk(ierr);
753c532df63SYohann   }
754c532df63SYohann   ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChk(ierr);
755c532df63SYohann 
756ab213215SJeremy L Thompson   // Clear v for transpose mode
757c532df63SYohann   if (tmode == CEED_TRANSPOSE) {
758c532df63SYohann     CeedInt length;
759c532df63SYohann     ierr = CeedVectorGetLength(v, &length); CeedChk(ierr);
760c532df63SYohann     ierr = cudaMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChk(ierr);
761c532df63SYohann   }
762ab213215SJeremy L Thompson 
763ab213215SJeremy L Thompson   // Apply basis operation
764ab213215SJeremy L Thompson   switch (emode) {
765ab213215SJeremy L Thompson   case CEED_EVAL_INTERP: {
766c532df63SYohann     CeedInt P1d, Q1d;
767c532df63SYohann     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
768c532df63SYohann     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
769c532df63SYohann     ierr = CeedCudaInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B);
770c532df63SYohann     CeedChk(ierr);
771cb0b5415Sjeremylt     void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
772ccf0fe6fSjeremylt                           &d_u, &d_v
773ccf0fe6fSjeremylt                          };
7744d537eeaSYohann     if (dim == 1) {
775d94769d2SYohann Dudouit       CeedInt elemsPerBlock = 32;
7764d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
7774d537eeaSYohann                                              ? 1 : 0 );
778d94769d2SYohann Dudouit       CeedInt sharedMem = elemsPerBlock*Q1d*sizeof(CeedScalar);
7794d537eeaSYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, 1,
7804d537eeaSYohann                                         elemsPerBlock, sharedMem,
781ab213215SJeremy L Thompson                                         interpargs); CeedChk(ierr);
782074be161SYohann Dudouit     } else if (dim == 2) {
7834247ecf3SYohann Dudouit       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
7840f70cdf6SJeremy L Thompson       // elemsPerBlock must be at least 1
785*d99fa3c5SJeremy L Thompson       CeedInt elemsPerBlock = CeedIntMax(Q1d < 7 ? optElems[Q1d]/ncomp : 1, 1);
7864d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
7874d537eeaSYohann                                              ? 1 : 0 );
7884247ecf3SYohann Dudouit       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
7894d537eeaSYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, Q1d,
7904d537eeaSYohann                                         ncomp*elemsPerBlock, sharedMem,
791ab213215SJeremy L Thompson                                         interpargs); CeedChk(ierr);
792074be161SYohann Dudouit     } else if (dim == 3) {
7933f63d318SYohann Dudouit       CeedInt elemsPerBlock = 1;
7944d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
7954d537eeaSYohann                                              ? 1 : 0 );
796698ebc35SYohann Dudouit       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
7974d537eeaSYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, Q1d,
7984d537eeaSYohann                                         ncomp*elemsPerBlock, sharedMem,
799ab213215SJeremy L Thompson                                         interpargs); CeedChk(ierr);
800074be161SYohann Dudouit     }
801ab213215SJeremy L Thompson   } break;
802ab213215SJeremy L Thompson   case CEED_EVAL_GRAD: {
803c532df63SYohann     CeedInt P1d, Q1d;
804c532df63SYohann     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
805c532df63SYohann     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
806c532df63SYohann     ierr = CeedCudaInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d,
807c532df63SYohann                                   Q1d, &data->c_B, &data->c_G);
808c532df63SYohann     CeedChk(ierr);
809cb0b5415Sjeremylt     void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
810ccf0fe6fSjeremylt                         &data->c_G, &d_u, &d_v
811ccf0fe6fSjeremylt                        };
8124d537eeaSYohann     if (dim == 1) {
813d94769d2SYohann Dudouit       CeedInt elemsPerBlock = 32;
8144d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
8154d537eeaSYohann                                              ? 1 : 0 );
816d94769d2SYohann Dudouit       CeedInt sharedMem = elemsPerBlock*Q1d*sizeof(CeedScalar);
817ab213215SJeremy L Thompson       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, 1,
818ab213215SJeremy L Thompson                                         elemsPerBlock, sharedMem, gradargs);
819c532df63SYohann       CeedChk(ierr);
820074be161SYohann Dudouit     } else if (dim == 2) {
8214247ecf3SYohann Dudouit       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
8220f70cdf6SJeremy L Thompson       // elemsPerBlock must be at least 1
823*d99fa3c5SJeremy L Thompson       CeedInt elemsPerBlock = CeedIntMax(Q1d < 7 ? optElems[Q1d]/ncomp : 1, 1);
8244d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
8254d537eeaSYohann                                              ? 1 : 0 );
8264247ecf3SYohann Dudouit       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
8274d537eeaSYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, Q1d,
8284d537eeaSYohann                                         ncomp*elemsPerBlock, sharedMem,
829ab213215SJeremy L Thompson                                         gradargs); CeedChk(ierr);
830074be161SYohann Dudouit     } else if (dim == 3) {
8313f63d318SYohann Dudouit       CeedInt elemsPerBlock = 1;
8324d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
8334d537eeaSYohann                                              ? 1 : 0 );
834698ebc35SYohann Dudouit       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
8354d537eeaSYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, Q1d,
8364d537eeaSYohann                                         ncomp*elemsPerBlock, sharedMem,
837ab213215SJeremy L Thompson                                         gradargs); CeedChk(ierr);
838074be161SYohann Dudouit     }
839ab213215SJeremy L Thompson   } break;
840ab213215SJeremy L Thompson   case CEED_EVAL_WEIGHT: {
841074be161SYohann Dudouit     CeedInt Q1d;
842074be161SYohann Dudouit     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
843c532df63SYohann     void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
844074be161SYohann Dudouit     if (dim == 1) {
845074be161SYohann Dudouit       const CeedInt elemsPerBlock = 32/Q1d;
8464d537eeaSYohann       const CeedInt gridsize = nelem/elemsPerBlock + ( (
8474d537eeaSYohann                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
8487f823360Sjeremylt       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d,
8497f823360Sjeremylt                                   elemsPerBlock, 1, weightargs);
8501226057fSYohann Dudouit       CeedChk(ierr);
851074be161SYohann Dudouit     } else if (dim == 2) {
852717ff8a3SYohann Dudouit       const CeedInt optElems = 32/(Q1d*Q1d);
853717ff8a3SYohann Dudouit       const CeedInt elemsPerBlock = optElems>0?optElems:1;
8544d537eeaSYohann       const CeedInt gridsize = nelem/elemsPerBlock + ( (
8554d537eeaSYohann                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
8564d537eeaSYohann       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d,
8574d537eeaSYohann                                   elemsPerBlock, weightargs);
8581226057fSYohann Dudouit       CeedChk(ierr);
859074be161SYohann Dudouit     } else if (dim == 3) {
860074be161SYohann Dudouit       const CeedInt gridsize = nelem;
8614d537eeaSYohann       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d, Q1d,
8624d537eeaSYohann                                   weightargs);
8631226057fSYohann Dudouit       CeedChk(ierr);
864074be161SYohann Dudouit     }
865ab213215SJeremy L Thompson   } break;
866ab213215SJeremy L Thompson   // LCOV_EXCL_START
867ab213215SJeremy L Thompson   // Evaluate the divergence to/from the quadrature points
868ab213215SJeremy L Thompson   case CEED_EVAL_DIV:
869ab213215SJeremy L Thompson     return CeedError(ceed, 1, "CEED_EVAL_DIV not supported");
870ab213215SJeremy L Thompson   // Evaluate the curl to/from the quadrature points
871ab213215SJeremy L Thompson   case CEED_EVAL_CURL:
872ab213215SJeremy L Thompson     return CeedError(ceed, 1, "CEED_EVAL_CURL not supported");
873ab213215SJeremy L Thompson   // Take no action, BasisApply should not have been called
874ab213215SJeremy L Thompson   case CEED_EVAL_NONE:
875ab213215SJeremy L Thompson     return CeedError(ceed, 1,
876ab213215SJeremy L Thompson                      "CEED_EVAL_NONE does not make sense in this context");
877ab213215SJeremy L Thompson     // LCOV_EXCL_STOP
878c532df63SYohann   }
879c532df63SYohann 
880ab213215SJeremy L Thompson   // Restore vectors
881c532df63SYohann   if (emode != CEED_EVAL_WEIGHT) {
882c532df63SYohann     ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChk(ierr);
883c532df63SYohann   }
884c532df63SYohann   ierr = CeedVectorRestoreArray(v, &d_v); CeedChk(ierr);
885c532df63SYohann   return 0;
886c532df63SYohann }
887c532df63SYohann 
888ab213215SJeremy L Thompson //------------------------------------------------------------------------------
889ab213215SJeremy L Thompson // Destroy basis
890ab213215SJeremy L Thompson //------------------------------------------------------------------------------
891c532df63SYohann static int CeedBasisDestroy_Cuda_shared(CeedBasis basis) {
892c532df63SYohann   int ierr;
893c532df63SYohann   Ceed ceed;
894c532df63SYohann   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
895c532df63SYohann 
896c532df63SYohann   CeedBasis_Cuda_shared *data;
897c532df63SYohann   ierr = CeedBasisGetData(basis, (void *) &data); CeedChk(ierr);
898c532df63SYohann 
899c532df63SYohann   CeedChk_Cu(ceed, cuModuleUnload(data->module));
900c532df63SYohann 
901c532df63SYohann   ierr = cudaFree(data->d_qweight1d); CeedChk_Cu(ceed, ierr);
902c532df63SYohann   ierr = cudaFree(data->d_interp1d); CeedChk_Cu(ceed, ierr);
903c532df63SYohann   ierr = cudaFree(data->d_grad1d); CeedChk_Cu(ceed, ierr);
9041958eb7cSJeremy L Thompson   ierr = cudaFree(data->d_collograd1d); CeedChk_Cu(ceed, ierr);
905c532df63SYohann 
906c532df63SYohann   ierr = CeedFree(&data); CeedChk(ierr);
907c532df63SYohann 
908c532df63SYohann   return 0;
909c532df63SYohann }
910c532df63SYohann 
911ab213215SJeremy L Thompson //------------------------------------------------------------------------------
912ab213215SJeremy L Thompson // Create tensor basis
913ab213215SJeremy L Thompson //------------------------------------------------------------------------------
914c532df63SYohann int CeedBasisCreateTensorH1_Cuda_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
915c532df63SYohann                                         const CeedScalar *interp1d,
916c532df63SYohann                                         const CeedScalar *grad1d,
917c532df63SYohann                                         const CeedScalar *qref1d,
918c532df63SYohann                                         const CeedScalar *qweight1d,
919c532df63SYohann                                         CeedBasis basis) {
920c532df63SYohann   int ierr;
921c532df63SYohann   Ceed ceed;
922c532df63SYohann   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
9234d537eeaSYohann   if (Q1d<P1d) {
924e9f4dca0SJeremy L Thompson     // LCOV_EXCL_START
9251226057fSYohann Dudouit     return CeedError(ceed, 1, "Backend does not implement underintegrated basis.");
926e9f4dca0SJeremy L Thompson     // LCOV_EXCL_STOP
9271226057fSYohann Dudouit   }
928c532df63SYohann   CeedBasis_Cuda_shared *data;
929c532df63SYohann   ierr = CeedCalloc(1, &data); CeedChk(ierr);
930c532df63SYohann 
931ab213215SJeremy L Thompson   // Copy basis data to GPU
932c532df63SYohann   const CeedInt qBytes = Q1d * sizeof(CeedScalar);
933c532df63SYohann   ierr = cudaMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Cu(ceed, ierr);
934c532df63SYohann   ierr = cudaMemcpy(data->d_qweight1d, qweight1d, qBytes,
935c532df63SYohann                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
936c532df63SYohann 
937c532df63SYohann   const CeedInt iBytes = qBytes * P1d;
938c532df63SYohann   ierr = cudaMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Cu(ceed, ierr);
939c532df63SYohann   ierr = cudaMemcpy(data->d_interp1d, interp1d, iBytes,
940c532df63SYohann                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
941c532df63SYohann 
942c532df63SYohann   ierr = cudaMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Cu(ceed, ierr);
943c532df63SYohann   ierr = cudaMemcpy(data->d_grad1d, grad1d, iBytes,
944c532df63SYohann                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
945c532df63SYohann 
946ab213215SJeremy L Thompson   // Compute collocated gradient and copy to GPU
947ac421f39SYohann   data->d_collograd1d = NULL;
948ac421f39SYohann   if (dim == 3 && Q1d >= P1d) {
949ac421f39SYohann     CeedScalar *collograd1d;
950ac421f39SYohann     ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChk(ierr);
951ac421f39SYohann     ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChk(ierr);
952ac421f39SYohann     ierr = cudaMalloc((void **)&data->d_collograd1d, qBytes * Q1d);
953ac421f39SYohann     CeedChk_Cu(ceed, ierr);
954ac421f39SYohann     ierr = cudaMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d,
955ac421f39SYohann                       cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
9561958eb7cSJeremy L Thompson     ierr = CeedFree(&collograd1d); CeedChk(ierr);
957ac421f39SYohann   }
958ac421f39SYohann 
959ab213215SJeremy L Thompson   // Compile basis kernels
960c532df63SYohann   CeedInt ncomp;
961c532df63SYohann   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
9624a6d4bbdSYohann Dudouit   ierr = CeedCompileCuda(ceed, kernelsShared, &data->module, 7,
963c532df63SYohann                          "Q1D", Q1d,
964c532df63SYohann                          "P1D", P1d,
965c532df63SYohann                          "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
966c532df63SYohann                              Q1d : P1d, dim),
967c532df63SYohann                          "BASIS_DIM", dim,
968c532df63SYohann                          "BASIS_NCOMP", ncomp,
969c532df63SYohann                          "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
970c532df63SYohann                          "BASIS_NQPT", CeedIntPow(Q1d, dim)
971c532df63SYohann                         ); CeedChk(ierr);
9724a6d4bbdSYohann Dudouit   ierr = CeedGetKernelCuda(ceed, data->module, "interp", &data->interp);
973c532df63SYohann   CeedChk(ierr);
9744a6d4bbdSYohann Dudouit   ierr = CeedGetKernelCuda(ceed, data->module, "grad", &data->grad);
975c532df63SYohann   CeedChk(ierr);
9764a6d4bbdSYohann Dudouit   ierr = CeedGetKernelCuda(ceed, data->module, "weight", &data->weight);
977c532df63SYohann   CeedChk(ierr);
978c532df63SYohann 
979ab213215SJeremy L Thompson   ierr = CeedBasisSetData(basis, (void *)&data); CeedChk(ierr);
980ab213215SJeremy L Thompson 
981ab213215SJeremy L Thompson   // Register backend functions
982c532df63SYohann   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
983c532df63SYohann                                 CeedBasisApplyTensor_Cuda_shared);
984c532df63SYohann   CeedChk(ierr);
985c532df63SYohann   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
986ab213215SJeremy L Thompson                                 CeedBasisDestroy_Cuda_shared); CeedChk(ierr);
987c532df63SYohann   return 0;
988c532df63SYohann }
989ab213215SJeremy L Thompson //------------------------------------------------------------------------------
990