xref: /libCEED/rust/libceed-sys/c-src/backends/cuda-shared/ceed-cuda-shared-basis.c (revision 49fd234cd5a6b1faaf6ecfc267b22ac88f378a38)
1c532df63SYohann // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2c532df63SYohann // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3c532df63SYohann // All Rights reserved. See files LICENSE and NOTICE for details.
4c532df63SYohann //
5c532df63SYohann // This file is part of CEED, a collection of benchmarks, miniapps, software
6c532df63SYohann // libraries and APIs for efficient high-order finite element and spectral
7c532df63SYohann // element discretizations for exascale applications. For more information and
8c532df63SYohann // source code availability see http://github.com/ceed.
9c532df63SYohann //
10c532df63SYohann // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11c532df63SYohann // a collaborative effort of two U.S. Department of Energy organizations (Office
12c532df63SYohann // of Science and the National Nuclear Security Administration) responsible for
13c532df63SYohann // the planning and preparation of a capable exascale ecosystem, including
14c532df63SYohann // software, applications, hardware, advanced system engineering and early
15c532df63SYohann // testbed platforms, in support of the nation's exascale computing imperative.
16c532df63SYohann 
17c532df63SYohann #include <ceed-backend.h>
18c532df63SYohann #include <ceed.h>
19c532df63SYohann #include "ceed-cuda-shared.h"
20c532df63SYohann #include "../cuda/ceed-cuda.h"
21c532df63SYohann 
22ab213215SJeremy L Thompson //------------------------------------------------------------------------------
23ab213215SJeremy L Thompson // Shared mem kernels
24ab213215SJeremy L Thompson //------------------------------------------------------------------------------
25cb0b5415Sjeremylt // *INDENT-OFF*
26c532df63SYohann static const char *kernelsShared = QUOTE(
27c532df63SYohann 
28ab213215SJeremy L Thompson //------------------------------------------------------------------------------
29ab213215SJeremy L Thompson // Sum input into output
30ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Element-wise accumulate: r_V[q] += r_U[q] for every q in [0, Q1D).
// Both arrays are expected to hold at least Q1D entries.
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  for (int q = 0; q < Q1D; ++q) {
    r_V[q] = r_V[q] + r_U[q];
  }
}
35c532df63SYohann 
36ab213215SJeremy L Thompson //------------------------------------------------------------------------------
37ab213215SJeremy L Thompson // 1D
38ab213215SJeremy L Thompson //------------------------------------------------------------------------------
39c532df63SYohann 
40ab213215SJeremy L Thompson //------------------------------------------------------------------------------
41ab213215SJeremy L Thompson // Read DoFs
42ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Stage component comp of element elem in the slice row owned by tidz
// (offset tidz*Q1D).
// - Entries [0, P1D) are copied from d_U, indexed as
//   node + elem*P1D + comp*P1D*nelem.
// - Entries [P1D, Q1D) are zeroed.
// Every thread of the row writes the entire row with identical values
// (tidx and tidy are unused).
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int tidz, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *slice) {
  const CeedScalar *src = d_U + elem*P1D + comp*P1D*nelem;
  CeedScalar *row = slice + tidz*Q1D;
  int node = 0;
  for (; node < P1D; ++node)
    row[node] = src[node];
  for (; node < Q1D; ++node)
    row[node] = 0.0;
}
52c532df63SYohann 
53ab213215SJeremy L Thompson //------------------------------------------------------------------------------
54ab213215SJeremy L Thompson // Write DoFs
55ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Store r_V as DoF tidx of component comp of element elem, indexed as
// tidx + elem*P1D + comp*P1D*nelem. Threads with tidx >= P1D own no DoF
// and store nothing.
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D)
    return;
  d_V[comp*P1D*nelem + elem*P1D + tidx] = r_V;
}
63c532df63SYohann 
64ab213215SJeremy L Thompson //------------------------------------------------------------------------------
65ab213215SJeremy L Thompson // Read quadrature point data
66ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Stage the Q1D quadrature values of component comp, direction dim, element
// elem in the slice row owned by tidz (offset tidz*Q1D).
// d_U is indexed as
//   q + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D.
// Every thread of the row copies the entire row (tidx and tidy are unused).
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int tidz, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *slice) {
  const int offset = elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  CeedScalar *row = slice + tidz*Q1D;
  for (int q = 0; q < Q1D; ++q)
    row[q] = d_U[offset + q];
}
75c532df63SYohann 
76ab213215SJeremy L Thompson //------------------------------------------------------------------------------
77ab213215SJeremy L Thompson // Write quadrature point data
78ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Store r_V as quadrature value tidx of component comp, direction dim,
// element elem. d_V is indexed as
//   tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D.
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  const int offset = elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  d_V[offset + tidx] = r_V;
}
85c532df63SYohann 
86ab213215SJeremy L Thompson //------------------------------------------------------------------------------
87ab213215SJeremy L Thompson // 1D tensor contraction
88ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Apply row tidx of B to the staged slice row of tidz:
//   V = sum_{p < P1D} B[p + tidx*P1D] * slice[p + tidz*Q1D].
// U and tidy are accepted for signature symmetry with the 2D/3D helpers and
// are not read. No barrier is issued here.
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  const CeedScalar *row = slice + tidz*Q1D;
  const CeedScalar *weights = B + tidx*P1D;
  CeedScalar acc = 0.0;
  for (int p = 0; p < P1D; ++p)
    acc += weights[p] * row[p];
  V = acc;
}
97c532df63SYohann 
98ab213215SJeremy L Thompson //------------------------------------------------------------------------------
99ab213215SJeremy L Thompson // 1D transpose tensor contraction
100ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Apply column tidx of B (transposed contraction) to the staged slice row of
// tidz:
//   V = sum_{q < Q1D} B[tidx + q*P1D] * slice[q + tidz*Q1D].
// B is indexed as B[node + qpt*P1D], i.e. Q1D rows of P1D entries (assumed
// to be exactly P1D*Q1D values).
// Only threads with tidx < P1D own a DoF. For tidx >= P1D the index
// tidx + q*P1D would run past the end of B on the last row, so those threads
// skip the sum and return V = 0 (writeDofs1d never stores their value).
// This mirrors the guard used by ContractTransposeX2d.
// U and tidy are not read. No barrier is issued here.
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  V = 0.0;
  if (tidx < P1D)
    for (int i = 0; i < Q1D; ++i)
      V += B[tidx + i*P1D] * slice[i + tidz*Q1D]; // Contract x direction
}
108c532df63SYohann 
109ab213215SJeremy L Thompson //------------------------------------------------------------------------------
110ab213215SJeremy L Thompson // 1D interpolate to quadrature points
111ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Interpolate every component of each element handled by this block with
// basis matrix c_B:
// - transpose == 0: from the P1D nodes to the Q1D quadrature points.
// - transpose != 0: the transpose of that map.
//
// Thread layout, as indexed here:
// - tidx selects the node/quadrature point.
// - tidz selects the element within the block (elem = elemBase + tidz).
// - The row of slice at offset tidz*Q1D is that element's staging buffer.
//
// Why the barriers:
// - readDofs1d / readQuads1d have every thread of a row write the whole row.
// - The contractions then read entries of that row.
// - Rows of Q1D threads are not warp-aligned in general. Without a barrier,
//   one thread could restage the row for the next component or element while
//   another thread of the same row is still reading the current one.
// - The __syncthreads() pair around each contraction closes that window.
//
// Why the block-uniform loop:
// - The element loop is driven by a base index that is uniform across the
//   block, so every thread runs the same number of iterations and reaches
//   every barrier.
// - Threads whose element is past nelem take part only in the barriers.
//
// NOTE(review): the caller must invoke this from block-uniform control flow
// (already required by the 2D path, whose helpers contain barriers).
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t = 0.0; // passed as U; the 1D contractions do not read it

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;

  for (CeedInt elemBase = blockIdx.x*blockDim.z; elemBase < nelem;
       elemBase += gridDim.x*blockDim.z) {
    const CeedInt elem = elemBase + tidz;
    const bool active = elem < nelem;
    for (int comp = 0; comp < BASIS_NCOMP; comp++) {
      // Stage this component of the element in the shared row
      if (active) {
        if (!transpose)
          readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        else
          readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
      }
      __syncthreads(); // row fully staged before anyone reads it
      if (active) {
        if (!transpose) {
          ContractX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
          writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
        } else {
          ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
          writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
        }
      }
      __syncthreads(); // all reads done before the row is restaged
    }
  }
}
140c532df63SYohann 
141ab213215SJeremy L Thompson //------------------------------------------------------------------------------
142ab213215SJeremy L Thompson // 1D derivatives at quadrature points
143ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Differentiate every component of each element handled by this block with
// derivative matrix c_G:
// - transpose == 0: from the P1D nodes to the Q1D quadrature points.
// - transpose != 0: the transpose of that map.
// 1D has a single derivative direction, so the quadrature data always uses
// dim = 0. c_B is accepted for signature symmetry and is not read.
//
// Thread layout, as indexed here:
// - tidx selects the node/quadrature point.
// - tidz selects the element within the block (elem = elemBase + tidz).
// - The row of slice at offset tidz*Q1D is that element's staging buffer.
//
// Why the barriers:
// - Every thread of a row writes the whole row, and other threads of the
//   same row then read it.
// - A __syncthreads() pair keeps one component/element from being restaged
//   while a thread (possibly in another warp) still reads the previous one.
//
// Why the block-uniform loop:
// - The element loop uses a block-uniform base index so all threads reach
//   every barrier.
// - Threads whose element is past nelem take part only in the barriers.
//
// NOTE(review): the caller must invoke this from block-uniform control flow.
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  CeedScalar r_U = 0.0; // passed as U; the 1D contractions do not read it
  CeedScalar r_V;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int dim = 0;

  for (CeedInt elemBase = blockIdx.x*blockDim.z; elemBase < nelem;
       elemBase += gridDim.x*blockDim.z) {
    const CeedInt elem = elemBase + tidz;
    const bool active = elem < nelem;
    for (int comp = 0; comp < BASIS_NCOMP; comp++) {
      // Stage this component of the element in the shared row
      if (active) {
        if (!transpose)
          readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        else
          readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice);
      }
      __syncthreads(); // row fully staged before anyone reads it
      if (active) {
        if (!transpose) {
          ContractX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
          writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
        } else {
          ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
          writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
        }
      }
      __syncthreads(); // all reads done before the row is restaged
    }
  }
}
174c532df63SYohann 
175ab213215SJeremy L Thompson //------------------------------------------------------------------------------
176ab213215SJeremy L Thompson // 1D Quadrature weights
177ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Broadcast the 1D quadrature weights to every element handled by this
// block: w[elem*Q1D + q] = qweight1d[q], with q = threadIdx.x.
// Elements are distributed over threadIdx.y within the block and strided
// by gridDim.x*blockDim.y across the grid.
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int q = threadIdx.x;
  const CeedScalar wq = qweight1d[q];
  CeedInt elem = blockIdx.x*blockDim.y + threadIdx.y;
  while (elem < nelem) {
    w[elem*Q1D + q] = wq;
    elem += gridDim.x*blockDim.y;
  }
}
188ab213215SJeremy L Thompson 
189ab213215SJeremy L Thompson //------------------------------------------------------------------------------
190ab213215SJeremy L Thompson // 2D
191ab213215SJeremy L Thompson //------------------------------------------------------------------------------
192ab213215SJeremy L Thompson 
193ab213215SJeremy L Thompson //------------------------------------------------------------------------------
194ab213215SJeremy L Thompson // Read DoFs
195ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Load DoF (tidx, tidy) of component comp of element elem into U. d_U is
// indexed as tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem.
// Threads outside the P1D x P1D node grid receive 0.
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar &U) {
  const bool ownsDof = tidx < P1D && tidy < P1D;
  if (ownsDof)
    U = d_U[comp*P1D*P1D*nelem + elem*P1D*P1D + tidy*P1D + tidx];
  else
    U = 0.0;
}
203c532df63SYohann 
204ab213215SJeremy L Thompson //------------------------------------------------------------------------------
205ab213215SJeremy L Thompson // Write DoFs
206ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Store r_V as DoF (tidx, tidy) of component comp of element elem, indexed
// as tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem. Threads outside
// the P1D x P1D node grid store nothing.
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D)
    return;
  d_V[comp*P1D*P1D*nelem + elem*P1D*P1D + tidy*P1D + tidx] = r_V;
}
214c532df63SYohann 
215ab213215SJeremy L Thompson //------------------------------------------------------------------------------
216ab213215SJeremy L Thompson // Read quadrature point data
217ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Load quadrature value (tidx, tidy) of component comp, direction dim,
// element elem into U. d_U is indexed as
//   tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem
//   + dim*BASIS_NCOMP*nelem*Q1D*Q1D.
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar &U) {
  const int base = dim*BASIS_NCOMP*nelem*Q1D*Q1D + comp*Q1D*Q1D*nelem;
  const int point = elem*Q1D*Q1D + tidy*Q1D + tidx;
  U = d_U[base + point];
}
225c532df63SYohann 
226ab213215SJeremy L Thompson //------------------------------------------------------------------------------
227ab213215SJeremy L Thompson // Write quadrature point data
228ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Store r_V as quadrature value (tidx, tidy) of component comp, direction
// dim, element elem. d_V is indexed as
//   tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem
//   + dim*BASIS_NCOMP*nelem*Q1D*Q1D.
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  const int base = dim*BASIS_NCOMP*nelem*Q1D*Q1D + comp*Q1D*Q1D*nelem;
  const int point = elem*Q1D*Q1D + tidy*Q1D + tidx;
  d_V[base + point] = r_V;
}
236c532df63SYohann 
237ab213215SJeremy L Thompson //------------------------------------------------------------------------------
238ab213215SJeremy L Thompson // 2D tensor contraction x
239ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Contract the x direction of a 2D tile.
// 1. Each thread publishes U at (tidx, tidy) of the tidz tile of slice
//    (tile size Q1D*Q1D).
// 2. It then computes V = sum_{p < P1D} B[p + tidx*P1D] * tile[p + tidy*Q1D].
// Contains two __syncthreads() calls, so every thread of the block must call
// it. The trailing barrier keeps the tile intact until all reads finish.
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  const int tile = tidz*Q1D*Q1D;
  slice[tile + tidy*Q1D + tidx] = U;
  __syncthreads();
  const CeedScalar *row = slice + tile + tidy*Q1D;
  const CeedScalar *weights = B + tidx*P1D;
  CeedScalar acc = 0.0;
  for (int p = 0; p < P1D; ++p)
    acc += weights[p] * row[p];
  V = acc;
  __syncthreads();
}
251c532df63SYohann 
252ab213215SJeremy L Thompson //------------------------------------------------------------------------------
253ab213215SJeremy L Thompson // 2D tensor contraction y
254ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Contract the y direction of a 2D tile.
// 1. Each thread publishes U at (tidx, tidy) of the tidz tile of slice
//    (tile size Q1D*Q1D).
// 2. It then computes V = sum_{p < P1D} B[p + tidy*P1D] * tile[tidx + p*Q1D].
// Contains two __syncthreads() calls, so every thread of the block must call
// it.
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  const int tile = tidz*Q1D*Q1D;
  slice[tile + tidy*Q1D + tidx] = U;
  __syncthreads();
  const CeedScalar *weights = B + tidy*P1D;
  CeedScalar acc = 0.0;
  for (int p = 0; p < P1D; ++p)
    acc += weights[p] * slice[tile + p*Q1D + tidx];
  V = acc;
  __syncthreads();
}
266c532df63SYohann 
267ab213215SJeremy L Thompson //------------------------------------------------------------------------------
268ab213215SJeremy L Thompson // 2D transpose tensor contraction y
269ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Transposed contraction in y.
// 1. Each thread publishes U at (tidx, tidy) of the tidz tile of slice.
// 2. Threads with tidy < P1D compute
//    V = sum_{q < Q1D} B[tidy + q*P1D] * tile[tidx + q*Q1D].
// 3. All other threads get V = 0.
// Contains two __syncthreads() calls, so every thread of the block must call
// it. The barriers sit outside the tidy guard for that reason.
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  const int tile = tidz*Q1D*Q1D;
  slice[tile + tidy*Q1D + tidx] = U;
  __syncthreads();
  CeedScalar acc = 0.0;
  if (tidy < P1D) {
    for (int q = 0; q < Q1D; ++q)
      acc += B[q*P1D + tidy] * slice[tile + q*Q1D + tidx];
  }
  V = acc;
  __syncthreads();
}
281c532df63SYohann 
282ab213215SJeremy L Thompson //------------------------------------------------------------------------------
283ab213215SJeremy L Thompson // 2D transpose tensor contraction x
284ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Transposed contraction in x.
// 1. Each thread publishes U at (tidx, tidy) of the tidz tile of slice.
// 2. Threads with tidx < P1D compute
//    V = sum_{q < Q1D} B[tidx + q*P1D] * tile[q + tidy*Q1D].
// 3. All other threads get V = 0.
// Contains two __syncthreads() calls, so every thread of the block must call
// it. The barriers sit outside the tidx guard for that reason.
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  const int tile = tidz*Q1D*Q1D;
  slice[tile + tidy*Q1D + tidx] = U;
  __syncthreads();
  CeedScalar acc = 0.0;
  if (tidx < P1D) {
    const CeedScalar *row = slice + tile + tidy*Q1D;
    for (int q = 0; q < Q1D; ++q)
      acc += B[q*P1D + tidx] * row[q];
  }
  V = acc;
  __syncthreads();
}
296c532df63SYohann 
297ab213215SJeremy L Thompson //------------------------------------------------------------------------------
298ab213215SJeremy L Thompson // 2D interpolate to quadrature points
299ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Interpolate each (element, component) handled by this block with basis
// matrix c_B:
// - transpose == 0: from the P1D x P1D nodes to the Q1D x Q1D quadrature
//   points.
// - transpose != 0: the transpose of that map.
//
// Thread layout, as indexed here:
// - (tidx, tidy) select a node/quadrature point.
// - tidz packs (element-in-block, component) as
//   tidz = blockElem*BASIS_NCOMP + comp.
// - Each tidz owns the Q1D*Q1D tile of slice at offset tidz*Q1D*Q1D.
//
// Why the block-uniform loop:
// - The contraction helpers execute __syncthreads(), so every thread of the
//   block must reach them the same number of times.
// - The element loop therefore runs over a block-uniform base index.
//   Indexing the loop by each thread's own element would make the trip count
//   depend on threadIdx.z and leave barriers divergent in the last iteration.
// - Threads whose element is past nelem contract zeros and skip every global
//   memory access.
//
// NOTE(review): requires blockDim.z >= BASIS_NCOMP (elemsPerBlock > 0);
// otherwise the loop never advances. Confirm against the launch code.
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elemBase = blockIdx.x*elemsPerBlock; elemBase < nelem;
       elemBase += gridDim.x*elemsPerBlock) {
    const CeedInt elem = elemBase + blockElem;
    const bool active = elem < nelem;
    r_V = 0.0; // inactive threads feed zeros into the shared contractions
    r_t = 0.0;
    if (!transpose) {
      if (active)
        readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      if (active)
        writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
    } else {
      if (active)
        readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      if (active)
        writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
333c532df63SYohann 
334ab213215SJeremy L Thompson //------------------------------------------------------------------------------
335ab213215SJeremy L Thompson // 2D derivatives at quadrature points
336ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Compute the 2D reference gradient of each (element, component) handled by
// this block, or (transpose != 0) apply its transpose.
// - Non-transpose: direction dim 0 uses (c_G in x, c_B in y) and dim 1 uses
//   (c_B in x, c_G in y). Each direction is written to its own dim-block of
//   d_V (offset dim*BASIS_NCOMP*nelem*Q1D*Q1D).
// - Transpose: both dim-blocks of d_U are read, the transposed contractions
//   are applied, and the two results are summed into one set of DoFs.
//
// Thread layout, as indexed here:
// - (tidx, tidy) select a node/quadrature point.
// - tidz = blockElem*BASIS_NCOMP + comp.
// - Each tidz owns the Q1D*Q1D tile of slice at offset tidz*Q1D*Q1D.
//
// Why the block-uniform loop:
// - The element loop is driven by a block-uniform base index so every thread
//   reaches every __syncthreads() inside the contraction helpers the same
//   number of times.
// - Threads whose element is past nelem contract zeros and skip every global
//   memory access.
//
// NOTE(review): requires blockDim.z >= BASIS_NCOMP (elemsPerBlock > 0);
// otherwise the loop never advances. Confirm against the launch code.
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V, CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elemBase = blockIdx.x*elemsPerBlock; elemBase < nelem;
       elemBase += gridDim.x*elemsPerBlock) {
    const CeedInt elem = elemBase + blockElem;
    const bool active = elem < nelem;
    r_U = 0.0; // inactive threads feed zeros into the shared contractions
    if (!transpose) {
      if (active)
        readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      if (active)
        writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 1;
      if (active)
        writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      dim = 0;
      if (active)
        readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 1;
      if (active)
        readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      r_V += r_U;
      if (active)
        writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
379c532df63SYohann 
380ab213215SJeremy L Thompson //------------------------------------------------------------------------------
381ab213215SJeremy L Thompson // 2D quadrature weights
382ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Broadcast the tensor-product 2D quadrature weights to every element
// handled by this block:
//   w[elem*Q1D*Q1D + qx + qy*Q1D] = qweight1d[qx]*qweight1d[qy],
// with (qx, qy) = (threadIdx.x, threadIdx.y). Elements are distributed over
// threadIdx.z and strided by gridDim.x*blockDim.z across the grid.
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const CeedScalar wxy = qweight1d[qx]*qweight1d[qy];
  CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z;
  while (elem < nelem) {
    w[elem*Q1D*Q1D + qy*Q1D + qx] = wxy;
    elem += gridDim.x*blockDim.z;
  }
}
394ab213215SJeremy L Thompson 
395ab213215SJeremy L Thompson //------------------------------------------------------------------------------
396ab213215SJeremy L Thompson // 3D
397ab213215SJeremy L Thompson //------------------------------------------------------------------------------
398ab213215SJeremy L Thompson 
399ab213215SJeremy L Thompson //------------------------------------------------------------------------------
400ab213215SJeremy L Thompson // Read DoFs
401ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Load the z-column of DoFs at (tidx, tidy) of component comp, element elem
// into r_U. d_U is indexed as
//   tidx + tidy*P1D + k*P1D*P1D + elem*P1D^3 + comp*P1D^3*nelem.
// - Entries [0, P1D) are the column values, or 0 for threads outside the
//   P1D x P1D node grid.
// - Entries [P1D, Q1D) are zeroed.
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *r_U) {
  const bool ownsDof = (tidx < P1D) && (tidy < P1D);
  const int layer = P1D*P1D;
  const int base = tidx + tidy*P1D + elem*P1D*P1D*P1D + comp*P1D*P1D*P1D*nelem;
  int k = 0;
  for (; k < P1D; ++k)
    r_U[k] = ownsDof ? d_U[base + k*layer] : 0.0;
  for (; k < Q1D; ++k)
    r_U[k] = 0.0;
}
413c532df63SYohann 
414ab213215SJeremy L Thompson //------------------------------------------------------------------------------
415*49fd234cSJeremy L Thompson // Write DoFs
416*49fd234cSJeremy L Thompson //------------------------------------------------------------------------------
// Store the P1D-long z-column r_V as the DoFs at (tidx, tidy) of component
// comp, element elem, indexed as
//   tidx + tidy*P1D + k*P1D*P1D + elem*P1D^3 + comp*P1D^3*nelem.
// Threads outside the P1D x P1D node grid store nothing.
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D)
    return;
  const int layer = P1D*P1D;
  const int base = tidx + tidy*P1D + elem*P1D*P1D*P1D + comp*P1D*P1D*P1D*nelem;
  for (int k = 0; k < P1D; ++k)
    d_V[base + k*layer] = r_V[k];
}
427*49fd234cSJeremy L Thompson 
428*49fd234cSJeremy L Thompson //------------------------------------------------------------------------------
429ab213215SJeremy L Thompson // Read quadrature point data
430ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Load the Q1D-long z-column of quadrature values at (tidx, tidy) of
// component comp, direction dim, element elem into r_U. d_U is indexed as
//   tidx + tidy*Q1D + k*Q1D*Q1D + elem*Q1D^3 + comp*Q1D^3*nelem
//   + dim*BASIS_NCOMP*nelem*Q1D^3.
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *r_U) {
  const int layer = Q1D*Q1D;
  const int base = tidx + tidy*Q1D + elem*Q1D*Q1D*Q1D +
                   comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
  for (int k = 0; k < Q1D; ++k)
    r_U[k] = d_U[base + k*layer];
}
439c532df63SYohann 
440ab213215SJeremy L Thompson //------------------------------------------------------------------------------
441ab213215SJeremy L Thompson // Write quadrature point data
442ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Store the Q1D-long z-column r_V as the quadrature values at (tidx, tidy)
// of component comp, direction dim, element elem. d_V is indexed as
//   tidx + tidy*Q1D + k*Q1D*Q1D + elem*Q1D^3 + comp*Q1D^3*nelem
//   + dim*BASIS_NCOMP*nelem*Q1D^3.
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar *r_V, CeedScalar *d_V) {
  const int layer = Q1D*Q1D;
  const int base = tidx + tidy*Q1D + elem*Q1D*Q1D*Q1D +
                   comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
  for (int k = 0; k < Q1D; ++k)
    d_V[base + k*layer] = r_V[k];
}
451c532df63SYohann 
452ab213215SJeremy L Thompson //------------------------------------------------------------------------------
453ab213215SJeremy L Thompson // 3D tensor contract x
454ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Contract the x direction for each of the first P1D z-layers.
// For each layer k:
// 1. The thread publishes U[k] at (tidx, tidy) of the tidz tile of slice
//    (tile size Q1D*Q1D).
// 2. It computes V[k] = sum_{p < P1D} B[p + tidx*P1D] * tile[p + tidy*Q1D].
// Only V[0..P1D) is written. Each layer issues two __syncthreads() calls,
// so every thread of the block must call this.
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B,
                                   CeedScalar *V) {
  const int tile = tidz*Q1D*Q1D;
  const CeedScalar *row = slice + tile + tidy*Q1D;
  const CeedScalar *weights = B + tidx*P1D;
  for (int k = 0; k < P1D; ++k) {
    slice[tile + tidy*Q1D + tidx] = U[k];
    __syncthreads();
    CeedScalar acc = 0.0;
    for (int p = 0; p < P1D; ++p)
      acc += weights[p] * row[p];
    V[k] = acc;
    __syncthreads();
  }
}
468c532df63SYohann 
469ab213215SJeremy L Thompson //------------------------------------------------------------------------------
470ab213215SJeremy L Thompson // 3D tensor contract y
471ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Apply B along y for each of the P1D z-levels held in registers U.
// Per level k: publish U[k] to the shared plane of tidz, synchronize, then
// accumulate B[tidy*P1D + i] * plane(y=i) over i < P1D into V[k].
// Contains __syncthreads(): every thread of the block must call this.
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B,
                                   CeedScalar *V) {
  CeedScalar *plane = slice + tidz*Q1D*Q1D; // Q1D x Q1D plane for index tidz
  const CeedScalar *row = B + tidy*P1D;     // entries B[tidy*P1D + i]
  for (int k = 0; k < P1D; ++k) {
    plane[tidx + tidy*Q1D] = U[k];
    __syncthreads();
    CeedScalar sum = 0.0;
    for (int i = 0; i < P1D; ++i)
      sum += row[i] * plane[tidx + i*Q1D]; // Contract y direction
    V[k] = sum;
    // Keep the plane intact until all threads have read it
    __syncthreads();
  }
}
485c532df63SYohann 
486ab213215SJeremy L Thompson //------------------------------------------------------------------------------
487ab213215SJeremy L Thompson // 3D tensor contract z
488ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Apply B along z entirely in registers: V[q] = sum over i < P1D of
// B[q*P1D + i] * U[i], for q < Q1D. slice, tidx, tidy and tidz are unused;
// they keep the signature uniform with the other contraction routines.
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B,
                                   CeedScalar *V) {
  for (int q = 0; q < Q1D; ++q) {
    const CeedScalar *row = B + q*P1D;
    CeedScalar sum = 0.0;
    for (int i = 0; i < P1D; ++i)
      sum += row[i] * U[i]; // Contract z direction
    V[q] = sum;
  }
}
499c532df63SYohann 
500ab213215SJeremy L Thompson //------------------------------------------------------------------------------
501ab213215SJeremy L Thompson // 3D transpose tensor contract z
502ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Apply B^T along z in registers: for k < P1D, V[k] = sum over i < Q1D of
// B[k + i*P1D] * U[i]. Entries V[P1D..Q1D) are set to zero.
// slice, tidx, tidy and tidz are unused (uniform signature).
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < Q1D; ++k) {
    CeedScalar sum = 0.0;
    if (k < P1D) {
      for (int i = 0; i < Q1D; ++i)
        sum += B[k + i*P1D] * U[i]; // Contract z direction
    }
    V[k] = sum;
  }
}
513c532df63SYohann 
514ab213215SJeremy L Thompson //------------------------------------------------------------------------------
515ab213215SJeremy L Thompson // 3D transpose tensor contract y
516ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Apply B^T along y for each of the P1D z-levels in U. Only threads with
// tidy < P1D produce a nonzero result. Others write 0 but still take part in
// the shared-plane exchange and barriers.
// Only V[0..P1D) is written.
// Contains __syncthreads(): every thread of the block must call this.
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  CeedScalar *plane = slice + tidz*Q1D*Q1D; // Q1D x Q1D plane for index tidz
  const bool hasOutput = tidy < P1D;
  for (int k = 0; k < P1D; ++k) {
    plane[tidx + tidy*Q1D] = U[k];
    __syncthreads();
    CeedScalar sum = 0.0;
    if (hasOutput)
      for (int i = 0; i < Q1D; ++i)
        sum += B[tidy + i*P1D] * plane[tidx + i*Q1D]; // Contract y direction
    V[k] = sum;
    // Keep the plane intact until all threads have read it
    __syncthreads();
  }
}
530c532df63SYohann 
531ab213215SJeremy L Thompson //------------------------------------------------------------------------------
532ab213215SJeremy L Thompson // 3D transpose tensor contract x
533ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Apply B^T along x for each of the P1D z-levels in U. Only threads with
// tidx < P1D produce a nonzero result. Others write 0 but still take part in
// the shared-plane exchange and barriers.
// Only V[0..P1D) is written.
// Contains __syncthreads(): every thread of the block must call this.
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  CeedScalar *plane = slice + tidz*Q1D*Q1D; // Q1D x Q1D plane for index tidz
  const bool hasOutput = tidx < P1D;
  for (int k = 0; k < P1D; ++k) {
    plane[tidx + tidy*Q1D] = U[k];
    __syncthreads();
    CeedScalar sum = 0.0;
    if (hasOutput)
      for (int i = 0; i < Q1D; ++i)
        sum += B[tidx + i*P1D] * plane[i + tidy*Q1D]; // Contract x direction
    V[k] = sum;
    // Keep the plane intact until all threads have read it
    __syncthreads();
  }
}
547c532df63SYohann 
548ab213215SJeremy L Thompson //------------------------------------------------------------------------------
549ab213215SJeremy L Thompson // 3D interpolate to quadrature points
550ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// 3D tensor interpolation for one element per (blockElem, comp) thread group.
// Forward: read DOFs, contract with c_B along x, y, z, and write quadrature
// values. Transpose: read quadrature values, apply c_B^T along z, y, x, and
// write DOFs.
// threadIdx.z enumerates (element-in-block, component) pairs.
// NOTE(review): the contraction helpers call __syncthreads(), so this assumes
// every thread of a block runs the same number of element iterations (e.g. one
// element per block). Confirm against the launch configuration.
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V[Q1D];
  CeedScalar r_t[Q1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int comp = tidz%BASIS_NCOMP;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int elemStride = gridDim.x*elemsPerBlock;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += elemStride) {
    // Start every element from zeroed register buffers
    for (int q = 0; q < Q1D; ++q) {
      r_V[q] = 0.0;
      r_t[q] = 0.0;
    }
    if (transpose) {
      readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
    } else {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
    }
  }
}
587c532df63SYohann 
588ab213215SJeremy L Thompson //------------------------------------------------------------------------------
589ab213215SJeremy L Thompson // 3D derivatives at quadrature points
590ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// 3D tensor gradient. The forward direction produces three quadrature-point
// fields (dim 0, 1, 2). Field d applies c_G along direction d and c_B along
// the other two, and each is written with writeQuads3d at that dim.
// The transpose direction reads the three fields, applies the transposed
// contractions, sums the three results with add(), and writes DOFs once.
// threadIdx.z enumerates (element-in-block, component) pairs.
// NOTE(review): the contraction helpers call __syncthreads(), so this assumes
// uniform element iteration counts across the block. Confirm against the launch.
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // NOTE(review): all buffers are sized Q1D; one could presumably use P1D
  CeedScalar r_U[Q1D];
  CeedScalar r_V[Q1D];
  CeedScalar r_t[Q1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int comp = tidz%BASIS_NCOMP;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int elemStride = gridDim.x*elemsPerBlock;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += elemStride) {
    if (transpose) {
      // dim 0: c_G^T along x; result accumulates in r_V
      readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      // dim 1: c_G^T along y
      readQuads3d(elem, tidx, tidy, comp, 1, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_G, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      // dim 2: c_G^T along z
      readQuads3d(elem, tidx, tidy, comp, 2, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    } else {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // dim 0: c_G along x
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      // dim 1: c_G along y
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_G, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeQuads3d(elem, tidx, tidy, comp, 1, nelem, r_V, d_V);
      // dim 2: c_G along z
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      writeQuads3d(elem, tidx, tidy, comp, 2, nelem, r_V, d_V);
    }
  }
}
650c532df63SYohann 
651ab213215SJeremy L Thompson //------------------------------------------------------------------------------
652ab213215SJeremy L Thompson // 3D quadrature weights
653ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Fill the 3D quadrature weights. The thread at (x, y, z) computes the tensor
// product qweight1d[x]*qweight1d[y]*qweight1d[z] once. It then stores that
// value at its slot in each element handled by this block (blocks stride
// over elements by gridDim.x).
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const int qz = threadIdx.z;
  const int nquads = Q1D*Q1D*Q1D;
  const int local = qx + qy*Q1D + qz*Q1D*Q1D;
  const CeedScalar weight = qweight1d[qx]*qweight1d[qy]*qweight1d[qz];
  for (int e = blockIdx.x; e < nelem; e += gridDim.x)
    w[e*nquads + local] = weight;
}
665ab213215SJeremy L Thompson 
666ab213215SJeremy L Thompson 
667ab213215SJeremy L Thompson //------------------------------------------------------------------------------
668ab213215SJeremy L Thompson // Basis kernels
669ab213215SJeremy L Thompson //------------------------------------------------------------------------------
670ab213215SJeremy L Thompson 
671ab213215SJeremy L Thompson //------------------------------------------------------------------------------
672ab213215SJeremy L Thompson // Interp kernel by dim
673ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Interpolation kernel. Dispatches on the compile-time BASIS_DIM to
// interp1d, interp2d or interp3d. A nonzero transpose selects the transposed
// application. slice is the dynamic shared-memory buffer whose size is chosen
// by the host launch.
extern "C" __global__ void interp(const CeedInt nelem, const int transpose,
                                  const CeedScalar *c_B,
                                  const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  // Declared as CeedScalar so the buffer type always matches the CeedScalar*
  // parameter of the per-dimension routines, rather than hard-coding double.
  extern __shared__ CeedScalar slice[];
  if (BASIS_DIM == 1) {
    interp1d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM == 2) {
    interp2d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM == 3) {
    interp3d(nelem, transpose, c_B, d_U, d_V, slice);
  }
}
687c532df63SYohann 
688ab213215SJeremy L Thompson //------------------------------------------------------------------------------
689ab213215SJeremy L Thompson // Grad kernel by dim
690ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Gradient kernel. Dispatches on the compile-time BASIS_DIM to grad1d,
// grad2d or grad3d. A nonzero transpose selects the transposed application.
// slice is the dynamic shared-memory buffer whose size is chosen by the host
// launch.
extern "C" __global__ void grad(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *c_G,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V) {
  // Declared as CeedScalar so the buffer type always matches the CeedScalar*
  // parameter of the per-dimension routines, rather than hard-coding double.
  extern __shared__ CeedScalar slice[];
  if (BASIS_DIM == 1) {
    grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM == 2) {
    grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM == 3) {
    grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  }
}
704c532df63SYohann 
705ab213215SJeremy L Thompson //------------------------------------------------------------------------------
706ab213215SJeremy L Thompson // Weight kernels by dim
707ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Quadrature-weight kernel. Selects weight1d, weight2d or weight3d from the
// compile-time BASIS_DIM and writes the weights for all nelem elements into v.
extern "C" __global__ void weight(const CeedInt nelem,
                                  const CeedScalar *__restrict__ qweight1d,
                                  CeedScalar *__restrict__ v) {
  switch (BASIS_DIM) {
  case 1:
    weight1d(nelem, qweight1d, v);
    break;
  case 2:
    weight2d(nelem, qweight1d, v);
    break;
  case 3:
    weight3d(nelem, qweight1d, v);
    break;
  }
}
719c532df63SYohann 
720c532df63SYohann );
721cb0b5415Sjeremylt // *INDENT-ON*
722c532df63SYohann 
723ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Device initialization
725ab213215SJeremy L Thompson //------------------------------------------------------------------------------
726c532df63SYohann int CeedCudaInitInterp(CeedScalar *d_B, CeedInt P1d, CeedInt Q1d,
727c532df63SYohann                        CeedScalar **c_B);
728c532df63SYohann int CeedCudaInitInterpGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P1d,
7297f823360Sjeremylt                            CeedInt Q1d, CeedScalar **c_B_ptr,
7307f823360Sjeremylt                            CeedScalar **c_G_ptr);
731c532df63SYohann 
732ab213215SJeremy L Thompson //------------------------------------------------------------------------------
733ab213215SJeremy L Thompson // Apply basis
734ab213215SJeremy L Thompson //------------------------------------------------------------------------------
735c532df63SYohann int CeedBasisApplyTensor_Cuda_shared(CeedBasis basis, const CeedInt nelem,
736c532df63SYohann                                      CeedTransposeMode tmode,
7377f823360Sjeremylt                                      CeedEvalMode emode, CeedVector u,
7387f823360Sjeremylt                                      CeedVector v) {
739c532df63SYohann   int ierr;
740c532df63SYohann   Ceed ceed;
741c532df63SYohann   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
742c532df63SYohann   Ceed_Cuda_shared *ceed_Cuda;
743c532df63SYohann   CeedGetData(ceed, (void *) &ceed_Cuda); CeedChk(ierr);
744c532df63SYohann   CeedBasis_Cuda_shared *data;
745c532df63SYohann   CeedBasisGetData(basis, (void *)&data); CeedChk(ierr);
746c532df63SYohann   const CeedInt transpose = tmode == CEED_TRANSPOSE;
7474247ecf3SYohann Dudouit   CeedInt dim, ncomp;
748074be161SYohann Dudouit   ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
7494247ecf3SYohann Dudouit   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
750c532df63SYohann 
751ab213215SJeremy L Thompson   // Read vectors
752c532df63SYohann   const CeedScalar *d_u;
753c532df63SYohann   CeedScalar *d_v;
754c532df63SYohann   if (emode != CEED_EVAL_WEIGHT) {
755c532df63SYohann     ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChk(ierr);
756c532df63SYohann   }
757c532df63SYohann   ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChk(ierr);
758c532df63SYohann 
759ab213215SJeremy L Thompson   // Clear v for transpose mode
760c532df63SYohann   if (tmode == CEED_TRANSPOSE) {
761c532df63SYohann     CeedInt length;
762c532df63SYohann     ierr = CeedVectorGetLength(v, &length); CeedChk(ierr);
763c532df63SYohann     ierr = cudaMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChk(ierr);
764c532df63SYohann   }
765ab213215SJeremy L Thompson 
766ab213215SJeremy L Thompson   // Apply basis operation
767ab213215SJeremy L Thompson   switch (emode) {
768ab213215SJeremy L Thompson   case CEED_EVAL_INTERP: {
769c532df63SYohann     CeedInt P1d, Q1d;
770c532df63SYohann     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
771c532df63SYohann     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
772c532df63SYohann     ierr = CeedCudaInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B);
773c532df63SYohann     CeedChk(ierr);
774cb0b5415Sjeremylt     void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
775ccf0fe6fSjeremylt                           &d_u, &d_v
776ccf0fe6fSjeremylt                          };
7774d537eeaSYohann     if (dim == 1) {
778d94769d2SYohann Dudouit       CeedInt elemsPerBlock = 32;
7794d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
7804d537eeaSYohann                                              ? 1 : 0 );
781d94769d2SYohann Dudouit       CeedInt sharedMem = elemsPerBlock*Q1d*sizeof(CeedScalar);
7824d537eeaSYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, 1,
7834d537eeaSYohann                                         elemsPerBlock, sharedMem,
784ab213215SJeremy L Thompson                                         interpargs); CeedChk(ierr);
785074be161SYohann Dudouit     } else if (dim == 2) {
7864247ecf3SYohann Dudouit       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
7874247ecf3SYohann Dudouit       CeedInt elemsPerBlock = Q1d < 7 ? optElems[Q1d]/ncomp : 1;
7884d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
7894d537eeaSYohann                                              ? 1 : 0 );
7904247ecf3SYohann Dudouit       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
7914d537eeaSYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, Q1d,
7924d537eeaSYohann                                         ncomp*elemsPerBlock, sharedMem,
793ab213215SJeremy L Thompson                                         interpargs); CeedChk(ierr);
794074be161SYohann Dudouit     } else if (dim == 3) {
7953f63d318SYohann Dudouit       CeedInt elemsPerBlock = 1;
7964d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
7974d537eeaSYohann                                              ? 1 : 0 );
798698ebc35SYohann Dudouit       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
7994d537eeaSYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, Q1d,
8004d537eeaSYohann                                         ncomp*elemsPerBlock, sharedMem,
801ab213215SJeremy L Thompson                                         interpargs); CeedChk(ierr);
802074be161SYohann Dudouit     }
803ab213215SJeremy L Thompson   } break;
804ab213215SJeremy L Thompson   case CEED_EVAL_GRAD: {
805c532df63SYohann     CeedInt P1d, Q1d;
806c532df63SYohann     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
807c532df63SYohann     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
808c532df63SYohann     ierr = CeedCudaInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d,
809c532df63SYohann                                   Q1d, &data->c_B, &data->c_G);
810c532df63SYohann     CeedChk(ierr);
811cb0b5415Sjeremylt     void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
812ccf0fe6fSjeremylt                         &data->c_G, &d_u, &d_v
813ccf0fe6fSjeremylt                        };
8144d537eeaSYohann     if (dim == 1) {
815d94769d2SYohann Dudouit       CeedInt elemsPerBlock = 32;
8164d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
8174d537eeaSYohann                                              ? 1 : 0 );
818d94769d2SYohann Dudouit       CeedInt sharedMem = elemsPerBlock*Q1d*sizeof(CeedScalar);
819ab213215SJeremy L Thompson       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, 1,
820ab213215SJeremy L Thompson                                         elemsPerBlock, sharedMem, gradargs);
821c532df63SYohann       CeedChk(ierr);
822074be161SYohann Dudouit     } else if (dim == 2) {
8234247ecf3SYohann Dudouit       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
8244247ecf3SYohann Dudouit       CeedInt elemsPerBlock = Q1d < 7 ? optElems[Q1d]/ncomp : 1;
8254d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
8264d537eeaSYohann                                              ? 1 : 0 );
8274247ecf3SYohann Dudouit       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
8284d537eeaSYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, Q1d,
8294d537eeaSYohann                                         ncomp*elemsPerBlock, sharedMem,
830ab213215SJeremy L Thompson                                         gradargs); CeedChk(ierr);
831074be161SYohann Dudouit     } else if (dim == 3) {
8323f63d318SYohann Dudouit       CeedInt elemsPerBlock = 1;
8334d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
8344d537eeaSYohann                                              ? 1 : 0 );
835698ebc35SYohann Dudouit       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
8364d537eeaSYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, Q1d,
8374d537eeaSYohann                                         ncomp*elemsPerBlock, sharedMem,
838ab213215SJeremy L Thompson                                         gradargs); CeedChk(ierr);
839074be161SYohann Dudouit     }
840ab213215SJeremy L Thompson   } break;
841ab213215SJeremy L Thompson   case CEED_EVAL_WEIGHT: {
842074be161SYohann Dudouit     CeedInt Q1d;
843074be161SYohann Dudouit     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
844c532df63SYohann     void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
845074be161SYohann Dudouit     if (dim == 1) {
846074be161SYohann Dudouit       const CeedInt elemsPerBlock = 32/Q1d;
8474d537eeaSYohann       const CeedInt gridsize = nelem/elemsPerBlock + ( (
8484d537eeaSYohann                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
8497f823360Sjeremylt       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d,
8507f823360Sjeremylt                                   elemsPerBlock, 1, weightargs);
8511226057fSYohann Dudouit       CeedChk(ierr);
852074be161SYohann Dudouit     } else if (dim == 2) {
853717ff8a3SYohann Dudouit       const CeedInt optElems = 32/(Q1d*Q1d);
854717ff8a3SYohann Dudouit       const CeedInt elemsPerBlock = optElems>0?optElems:1;
8554d537eeaSYohann       const CeedInt gridsize = nelem/elemsPerBlock + ( (
8564d537eeaSYohann                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
8574d537eeaSYohann       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d,
8584d537eeaSYohann                                   elemsPerBlock, weightargs);
8591226057fSYohann Dudouit       CeedChk(ierr);
860074be161SYohann Dudouit     } else if (dim == 3) {
861074be161SYohann Dudouit       const CeedInt gridsize = nelem;
8624d537eeaSYohann       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d, Q1d,
8634d537eeaSYohann                                   weightargs);
8641226057fSYohann Dudouit       CeedChk(ierr);
865074be161SYohann Dudouit     }
866ab213215SJeremy L Thompson   } break;
867ab213215SJeremy L Thompson   // LCOV_EXCL_START
868ab213215SJeremy L Thompson   // Evaluate the divergence to/from the quadrature points
869ab213215SJeremy L Thompson   case CEED_EVAL_DIV:
870ab213215SJeremy L Thompson     return CeedError(ceed, 1, "CEED_EVAL_DIV not supported");
871ab213215SJeremy L Thompson   // Evaluate the curl to/from the quadrature points
872ab213215SJeremy L Thompson   case CEED_EVAL_CURL:
873ab213215SJeremy L Thompson     return CeedError(ceed, 1, "CEED_EVAL_CURL not supported");
874ab213215SJeremy L Thompson   // Take no action, BasisApply should not have been called
875ab213215SJeremy L Thompson   case CEED_EVAL_NONE:
876ab213215SJeremy L Thompson     return CeedError(ceed, 1,
877ab213215SJeremy L Thompson                      "CEED_EVAL_NONE does not make sense in this context");
878ab213215SJeremy L Thompson     // LCOV_EXCL_STOP
879c532df63SYohann   }
880c532df63SYohann 
881ab213215SJeremy L Thompson   // Restore vectors
882c532df63SYohann   if (emode != CEED_EVAL_WEIGHT) {
883c532df63SYohann     ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChk(ierr);
884c532df63SYohann   }
885c532df63SYohann   ierr = CeedVectorRestoreArray(v, &d_v); CeedChk(ierr);
886c532df63SYohann   return 0;
887c532df63SYohann }
888c532df63SYohann 
889ab213215SJeremy L Thompson //------------------------------------------------------------------------------
890ab213215SJeremy L Thompson // Destroy basis
891ab213215SJeremy L Thompson //------------------------------------------------------------------------------
892c532df63SYohann static int CeedBasisDestroy_Cuda_shared(CeedBasis basis) {
893c532df63SYohann   int ierr;
894c532df63SYohann   Ceed ceed;
895c532df63SYohann   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
896c532df63SYohann 
897c532df63SYohann   CeedBasis_Cuda_shared *data;
898c532df63SYohann   ierr = CeedBasisGetData(basis, (void *) &data); CeedChk(ierr);
899c532df63SYohann 
900c532df63SYohann   CeedChk_Cu(ceed, cuModuleUnload(data->module));
901c532df63SYohann 
902c532df63SYohann   ierr = cudaFree(data->d_qweight1d); CeedChk_Cu(ceed, ierr);
903c532df63SYohann   ierr = cudaFree(data->d_interp1d); CeedChk_Cu(ceed, ierr);
904c532df63SYohann   ierr = cudaFree(data->d_grad1d); CeedChk_Cu(ceed, ierr);
905c532df63SYohann 
906c532df63SYohann   ierr = CeedFree(&data); CeedChk(ierr);
907c532df63SYohann 
908c532df63SYohann   return 0;
909c532df63SYohann }
910c532df63SYohann 
911ab213215SJeremy L Thompson //------------------------------------------------------------------------------
912ab213215SJeremy L Thompson // Create tensor basis
913ab213215SJeremy L Thompson //------------------------------------------------------------------------------
914c532df63SYohann int CeedBasisCreateTensorH1_Cuda_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
915c532df63SYohann                                         const CeedScalar *interp1d,
916c532df63SYohann                                         const CeedScalar *grad1d,
917c532df63SYohann                                         const CeedScalar *qref1d,
918c532df63SYohann                                         const CeedScalar *qweight1d,
919c532df63SYohann                                         CeedBasis basis) {
920c532df63SYohann   int ierr;
921c532df63SYohann   Ceed ceed;
922c532df63SYohann   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
9234d537eeaSYohann   if (Q1d<P1d) {
9241226057fSYohann Dudouit     return CeedError(ceed, 1, "Backend does not implement underintegrated basis.");
9251226057fSYohann Dudouit   }
926c532df63SYohann   CeedBasis_Cuda_shared *data;
927c532df63SYohann   ierr = CeedCalloc(1, &data); CeedChk(ierr);
928c532df63SYohann 
929ab213215SJeremy L Thompson   // Copy basis data to GPU
930c532df63SYohann   const CeedInt qBytes = Q1d * sizeof(CeedScalar);
931c532df63SYohann   ierr = cudaMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Cu(ceed, ierr);
932c532df63SYohann   ierr = cudaMemcpy(data->d_qweight1d, qweight1d, qBytes,
933c532df63SYohann                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
934c532df63SYohann 
935c532df63SYohann   const CeedInt iBytes = qBytes * P1d;
936c532df63SYohann   ierr = cudaMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Cu(ceed, ierr);
937c532df63SYohann   ierr = cudaMemcpy(data->d_interp1d, interp1d, iBytes,
938c532df63SYohann                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
939c532df63SYohann 
940c532df63SYohann   ierr = cudaMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Cu(ceed, ierr);
941c532df63SYohann   ierr = cudaMemcpy(data->d_grad1d, grad1d, iBytes,
942c532df63SYohann                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
943c532df63SYohann 
944ab213215SJeremy L Thompson   // Compute collocated gradient and copy to GPU
945ac421f39SYohann   data->d_collograd1d = NULL;
946ac421f39SYohann   if (dim == 3 && Q1d >= P1d) {
947ac421f39SYohann     CeedScalar *collograd1d;
948ac421f39SYohann     ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChk(ierr);
949ac421f39SYohann     ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChk(ierr);
950ac421f39SYohann     ierr = cudaMalloc((void **)&data->d_collograd1d, qBytes * Q1d);
951ac421f39SYohann     CeedChk_Cu(ceed, ierr);
952ac421f39SYohann     ierr = cudaMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d,
953ac421f39SYohann                       cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
954ac421f39SYohann   }
955ac421f39SYohann 
956ab213215SJeremy L Thompson   // Compile basis kernels
957c532df63SYohann   CeedInt ncomp;
958c532df63SYohann   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
9594a6d4bbdSYohann Dudouit   ierr = CeedCompileCuda(ceed, kernelsShared, &data->module, 7,
960c532df63SYohann                          "Q1D", Q1d,
961c532df63SYohann                          "P1D", P1d,
962c532df63SYohann                          "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
963c532df63SYohann                              Q1d : P1d, dim),
964c532df63SYohann                          "BASIS_DIM", dim,
965c532df63SYohann                          "BASIS_NCOMP", ncomp,
966c532df63SYohann                          "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
967c532df63SYohann                          "BASIS_NQPT", CeedIntPow(Q1d, dim)
968c532df63SYohann                         ); CeedChk(ierr);
9694a6d4bbdSYohann Dudouit   ierr = CeedGetKernelCuda(ceed, data->module, "interp", &data->interp);
970c532df63SYohann   CeedChk(ierr);
9714a6d4bbdSYohann Dudouit   ierr = CeedGetKernelCuda(ceed, data->module, "grad", &data->grad);
972c532df63SYohann   CeedChk(ierr);
9734a6d4bbdSYohann Dudouit   ierr = CeedGetKernelCuda(ceed, data->module, "weight", &data->weight);
974c532df63SYohann   CeedChk(ierr);
975c532df63SYohann 
976ab213215SJeremy L Thompson   ierr = CeedBasisSetData(basis, (void *)&data); CeedChk(ierr);
977ab213215SJeremy L Thompson 
978ab213215SJeremy L Thompson   // Register backend functions
979c532df63SYohann   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
980c532df63SYohann                                 CeedBasisApplyTensor_Cuda_shared);
981c532df63SYohann   CeedChk(ierr);
982c532df63SYohann   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
983ab213215SJeremy L Thompson                                 CeedBasisDestroy_Cuda_shared); CeedChk(ierr);
984c532df63SYohann   return 0;
985c532df63SYohann }
986ab213215SJeremy L Thompson //------------------------------------------------------------------------------
987