xref: /libCEED/rust/libceed-sys/c-src/backends/cuda-shared/ceed-cuda-shared-basis.c (revision 18d499f16311f0272db1a57764883905102e8c42)
1c532df63SYohann // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2c532df63SYohann // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3c532df63SYohann // All Rights reserved. See files LICENSE and NOTICE for details.
4c532df63SYohann //
5c532df63SYohann // This file is part of CEED, a collection of benchmarks, miniapps, software
6c532df63SYohann // libraries and APIs for efficient high-order finite element and spectral
7c532df63SYohann // element discretizations for exascale applications. For more information and
8c532df63SYohann // source code availability see http://github.com/ceed.
9c532df63SYohann //
10c532df63SYohann // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11c532df63SYohann // a collaborative effort of two U.S. Department of Energy organizations (Office
12c532df63SYohann // of Science and the National Nuclear Security Administration) responsible for
13c532df63SYohann // the planning and preparation of a capable exascale ecosystem, including
14c532df63SYohann // software, applications, hardware, advanced system engineering and early
15c532df63SYohann // testbed platforms, in support of the nation's exascale computing imperative.
16c532df63SYohann 
17c532df63SYohann #include "ceed-cuda-shared.h"
18c532df63SYohann 
//------------------------------------------------------------------------------
// Shared-memory kernels
//------------------------------------------------------------------------------
22cb0b5415Sjeremylt // *INDENT-OFF*
23c532df63SYohann static const char *kernelsShared = QUOTE(
24c532df63SYohann 
//------------------------------------------------------------------------------
// Sum input into output: r_V[k] += r_U[k] for every k in [0, P1D).
// Both arrays must provide at least P1D entries.
//------------------------------------------------------------------------------
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  for (int k = 0; k < P1D; ++k) {
    const CeedScalar incr = r_U[k];
    r_V[k] += incr;
  }
}
32c532df63SYohann 
//------------------------------------------------------------------------------
// 1D
//------------------------------------------------------------------------------
36c532df63SYohann 
//------------------------------------------------------------------------------
// Read DoFs (1D)
//
// Copies the P1D nodal values of component comp of element elem from d_U
// (layout: node fastest, then element, then component) into row tidz of the
// scratch buffer slice, i.e. slice[tidz*T1D + n]. Entries P1D..Q1D-1 of the
// row are zero-filled.
//
// The copy does not depend on tidx: every thread that shares a tidz writes
// the same values into the same row. tidx and tidy are unused.
//------------------------------------------------------------------------------
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int tidz, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *slice) {
  CeedScalar *row = slice + tidz*T1D;
  const CeedScalar *src = d_U + elem*P1D + comp*P1D*nelem;
  for (int n = 0; n < P1D; ++n)
    row[n] = src[n];
  for (int n = P1D; n < Q1D; ++n)
    row[n] = 0.0;
}
49c532df63SYohann 
//------------------------------------------------------------------------------
// Write DoFs (1D)
//
// Thread tidx stores r_V as node tidx of component comp of element elem in
// d_V (layout: node fastest, then element, then component). Threads with
// tidx >= P1D own no node and store nothing. tidy is unused.
//------------------------------------------------------------------------------
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D)
    return;
  const int offset = elem*P1D + comp*P1D*nelem;
  d_V[offset + tidx] = r_V;
}
60c532df63SYohann 
//------------------------------------------------------------------------------
// Read quadrature point data (1D)
//
// Copies the Q1D quadrature values of component comp, direction dim, element
// elem from d_U (layout: point fastest, then element, then component, then
// dim) into row tidz of slice. Entries Q1D..P1D-1 of the row are zero-filled.
//
// As with readDofs1d, every thread sharing a tidz writes identical values.
// tidx and tidy are unused.
//------------------------------------------------------------------------------
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int tidz, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *slice) {
  CeedScalar *row = slice + tidz*T1D;
  const CeedScalar *src = d_U + elem*Q1D + comp*Q1D*nelem +
                          dim*BASIS_NCOMP*nelem*Q1D;
  for (int q = 0; q < Q1D; ++q)
    row[q] = src[q];
  for (int q = Q1D; q < P1D; ++q)
    row[q] = 0.0;
}
74c532df63SYohann 
//------------------------------------------------------------------------------
// Write quadrature point data (1D)
//
// Thread tidx stores r_V as quadrature point tidx of component comp,
// direction dim, element elem in d_V (layout: point fastest, then element,
// then component, then dim). Threads with tidx >= Q1D store nothing.
//------------------------------------------------------------------------------
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx >= Q1D)
    return;
  const int offset = elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  d_V[offset + tidx] = r_V;
}
85c532df63SYohann 
//------------------------------------------------------------------------------
// 1D tensor contraction
//
// V = sum_{i < P1D} B[tidx*P1D + i] * slice[tidz*T1D + i]. This applies row
// tidx of B (indexed as Q1D rows of P1D entries) to the nodal row stored in
// slice for this tidz. U is unused; the input comes from slice.
//
// Threads with tidx >= Q1D have no row of B. Their result is never stored
// (writeQuads1d keeps only tidx < Q1D), so they return V = 0 instead of
// reading past the end of B. This matches the guard in ContractX2d.
//
// The trailing __syncthreads() prevents any thread from overwriting this row
// of slice (next component or next element) while another thread sharing the
// same tidz is still reading it. Those threads are not guaranteed to run in
// lockstep (for example when the row straddles two warps). Like the 2D/3D
// contraction helpers, this must be reached by every thread of the block.
//------------------------------------------------------------------------------
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  V = 0.0;
  if (tidx < Q1D)
    for (int i = 0; i < P1D; ++i)
      V += B[i + tidx*P1D] * slice[i + tidz*T1D]; // Contract x direction
  __syncthreads();
}
97c532df63SYohann 
//------------------------------------------------------------------------------
// 1D transpose tensor contraction
//
// V = sum_{i < Q1D} B[i*P1D + tidx] * slice[tidz*T1D + i]. This applies
// column tidx of B (indexed as Q1D rows of P1D entries) to the quadrature row
// stored in slice for this tidz. U is unused; the input comes from slice.
//
// Threads with tidx >= P1D have no column of B. Indexing it would read into
// the next row and, for the last row, past the end of B. Their result is
// never stored (writeDofs1d keeps only tidx < P1D), so they return V = 0.
// This matches the guard in ContractTransposeX2d.
//
// The trailing __syncthreads() keeps the row of slice intact until every
// thread sharing this tidz has finished reading it, before the next component
// or element overwrites it. It must be reached by every thread of the block.
//------------------------------------------------------------------------------
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  V = 0.0;
  if (tidx < P1D)
    for (int i = 0; i < Q1D; ++i)
      V += B[tidx + i*P1D] * slice[i + tidz*T1D]; // Contract x direction
  __syncthreads();
}
108c532df63SYohann 
//------------------------------------------------------------------------------
// 1D interpolate to quadrature points
//
// Each z-layer of the block (threadIdx.z) processes one element at a time.
// Elements start at blockIdx.x*blockDim.z + threadIdx.z and advance by
// gridDim.x*blockDim.z. Every component is handled in turn:
//   transpose == 0 : nodal values -> quadrature values, contracting with B
//   transpose != 0 : quadrature values -> nodal values, contracting with B^T
// slice is the per-block scratch buffer used by the 1D read/contract helpers.
//------------------------------------------------------------------------------
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;

  CeedScalar r_V;
  CeedScalar r_t; // placeholder for the (unused) U argument of the contractions

  for (CeedInt elem = blockIdx.x*blockDim.z + tidz; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for (int comp = 0; comp < BASIS_NCOMP; ++comp) {
      if (transpose) {
        readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      } else {
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      }
    }
  }
}
140c532df63SYohann 
//------------------------------------------------------------------------------
// 1D derivatives at quadrature points
//
// Same element/component traversal as interp1d, but contracts with c_G
// instead of c_B (c_B is not used in 1D). Quadrature data is always read or
// written with dim index 0.
//   transpose == 0 : nodal values -> quadrature values via G
//   transpose != 0 : quadrature values -> nodal values via G^T
//------------------------------------------------------------------------------
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int dim = 0; // only dim index 0 exists in 1D

  CeedScalar r_U; // placeholder for the (unused) U argument of the contractions
  CeedScalar r_V;

  for (CeedInt elem = blockIdx.x*blockDim.z + tidz; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for (int comp = 0; comp < BASIS_NCOMP; ++comp) {
      if (transpose) {
        readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      } else {
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      }
    }
  }
}
174c532df63SYohann 
//------------------------------------------------------------------------------
// 1D quadrature weights
//
// Thread threadIdx.x copies qweight1d[threadIdx.x] into position
// elem*Q1D + threadIdx.x of w for each element it visits. Elements start at
// blockIdx.x*blockDim.y + threadIdx.y and advance by gridDim.x*blockDim.y.
// NOTE(review): assumes blockDim.x == Q1D so each thread maps to one
// quadrature point -- confirm at the launch site.
//------------------------------------------------------------------------------
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int q = threadIdx.x;
  const CeedScalar wq = qweight1d[q];
  const CeedInt first = blockIdx.x*blockDim.y + threadIdx.y;
  for (CeedInt elem = first; elem < nelem; elem += gridDim.x*blockDim.y)
    w[elem*Q1D + q] = wq;
}
188ab213215SJeremy L Thompson 
//------------------------------------------------------------------------------
// 2D
//------------------------------------------------------------------------------
192ab213215SJeremy L Thompson 
//------------------------------------------------------------------------------
// Read DoFs (2D)
//
// Thread (tidx, tidy) loads node (tidx, tidy) of component comp of element
// elem from d_U (layout: x fastest, then y, then element, then component)
// into U. Threads outside the P1D x P1D node grid get U = 0.
//------------------------------------------------------------------------------
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar &U) {
  if (tidx < P1D && tidy < P1D) {
    const int node = tidx + tidy*P1D;
    U = d_U[node + elem*P1D*P1D + comp*P1D*P1D*nelem];
  } else {
    U = 0.0;
  }
}
203c532df63SYohann 
//------------------------------------------------------------------------------
// Write DoFs (2D)
//
// Thread (tidx, tidy) stores r_V as node (tidx, tidy) of component comp of
// element elem in d_V (same layout as readDofs2d). Threads outside the
// P1D x P1D node grid store nothing.
//------------------------------------------------------------------------------
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  const bool owns = (tidx < P1D) && (tidy < P1D);
  if (!owns)
    return;
  d_V[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] = r_V;
}
214c532df63SYohann 
//------------------------------------------------------------------------------
// Read quadrature point data (2D)
//
// Thread (tidx, tidy) loads point (tidx, tidy) of component comp, direction
// dim, element elem from d_U (layout: x fastest, then y, then element, then
// component, then dim) into U. Threads outside the Q1D x Q1D grid get U = 0.
//------------------------------------------------------------------------------
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar &U ) {
  if (tidx >= Q1D || tidy >= Q1D) {
    U = 0.0;
    return;
  }
  const int base = elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                   dim*BASIS_NCOMP*nelem*Q1D*Q1D;
  U = d_U[base + tidx + tidy*Q1D];
}
226c532df63SYohann 
//------------------------------------------------------------------------------
// Write quadrature point data (2D)
//
// Thread (tidx, tidy) stores r_V as point (tidx, tidy) of component comp,
// direction dim, element elem in d_V (same layout as readQuads2d). Threads
// outside the Q1D x Q1D grid store nothing.
//------------------------------------------------------------------------------
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx >= Q1D || tidy >= Q1D)
    return;
  const int base = elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                   dim*BASIS_NCOMP*nelem*Q1D*Q1D;
  d_V[base + tidx + tidy*Q1D] = r_V;
}
238c532df63SYohann 
//------------------------------------------------------------------------------
// 2D tensor contraction x
//
// Each thread publishes U into cell (tidx, tidy) of plane tidz of slice.
// After a barrier, threads with tidx < Q1D compute
//   V = sum_{i < P1D} B[tidx*P1D + i] * plane[tidy][i]
// and all other threads get V = 0. The second barrier lets the next call
// reuse slice safely. Both barriers must be reached by the whole block.
//------------------------------------------------------------------------------
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  CeedScalar *plane = slice + tidz*T1D*T1D;
  plane[tidx + tidy*T1D] = U;
  __syncthreads();
  CeedScalar acc = 0.0;
  if (tidx < Q1D)
    for (int i = 0; i < P1D; ++i)
      acc += B[i + tidx*P1D] * plane[i + tidy*T1D]; // Contract x direction
  V = acc;
  __syncthreads();
}
254c532df63SYohann 
//------------------------------------------------------------------------------
// 2D tensor contraction y
//
// Each thread publishes U into cell (tidx, tidy) of plane tidz of slice.
// After a barrier, threads with tidy < Q1D compute
//   V = sum_{j < P1D} B[tidy*P1D + j] * plane[j][tidx]
// and all other threads get V = 0. The second barrier protects slice for
// the next call. Both barriers must be reached by the whole block.
//------------------------------------------------------------------------------
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  CeedScalar *plane = slice + tidz*T1D*T1D;
  plane[tidx + tidy*T1D] = U;
  __syncthreads();
  CeedScalar acc = 0.0;
  if (tidy < Q1D)
    for (int j = 0; j < P1D; ++j)
      acc += B[j + tidy*P1D] * plane[tidx + j*T1D]; // Contract y direction
  V = acc;
  __syncthreads();
}
270c532df63SYohann 
//------------------------------------------------------------------------------
// 2D transpose tensor contraction y
//
// Each thread publishes U into cell (tidx, tidy) of plane tidz of slice.
// After a barrier, threads with tidy < P1D compute
//   V = sum_{q < Q1D} B[q*P1D + tidy] * plane[q][tidx]
// (column tidy of B) and all other threads get V = 0. The second barrier
// protects slice for the next call. Both barriers must be reached by the
// whole block.
//------------------------------------------------------------------------------
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  CeedScalar *plane = slice + tidz*T1D*T1D;
  plane[tidx + tidy*T1D] = U;
  __syncthreads();
  CeedScalar acc = 0.0;
  if (tidy < P1D)
    for (int q = 0; q < Q1D; ++q)
      acc += B[tidy + q*P1D] * plane[tidx + q*T1D]; // Contract y direction
  V = acc;
  __syncthreads();
}
285c532df63SYohann 
//------------------------------------------------------------------------------
// 2D transpose tensor contraction x
//
// Each thread publishes U into cell (tidx, tidy) of plane tidz of slice.
// After a barrier, threads with tidx < P1D compute
//   V = sum_{q < Q1D} B[q*P1D + tidx] * plane[tidy][q]
// (column tidx of B) and all other threads get V = 0. The second barrier
// protects slice for the next call. Both barriers must be reached by the
// whole block.
//------------------------------------------------------------------------------
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  CeedScalar *plane = slice + tidz*T1D*T1D;
  plane[tidx + tidy*T1D] = U;
  __syncthreads();
  CeedScalar acc = 0.0;
  if (tidx < P1D)
    for (int q = 0; q < Q1D; ++q)
      acc += B[tidx + q*P1D] * plane[q + tidy*T1D]; // Contract x direction
  V = acc;
  __syncthreads();
}
300c532df63SYohann 
//------------------------------------------------------------------------------
// 2D interpolate to quadrature points
//
// threadIdx.z packs (element slot, component): tidz/BASIS_NCOMP selects the
// element slot within the block and tidz%BASIS_NCOMP the component. The block
// therefore handles blockDim.z/BASIS_NCOMP elements at a time and advances by
// gridDim.x times that count.
//   transpose == 0 : nodal values -> quadrature values (B along x, then y)
//   transpose != 0 : quadrature values -> nodal values (B^T along y, then x)
//
// NOTE(review): assumes blockDim.z is a multiple of BASIS_NCOMP; otherwise
// the trailing z-layers would alias the next block's first element slot.
// Confirm at the launch site.
// NOTE(review): the contraction helpers call __syncthreads(), yet element
// slots past nelem leave the loop early. This presumably relies on those
// threads having exited the kernel by then. Confirm.
//------------------------------------------------------------------------------
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  // r_V and r_t are fully assigned by the read/contract helpers on every
  // path, so they need no reset per element. comp is loop-invariant.
  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if (!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
    } else {
      readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
337c532df63SYohann 
//------------------------------------------------------------------------------
// 2D derivatives at quadrature points
//
// Same thread layout as interp2d (tidz packs element slot and component).
//   transpose == 0 : from the nodal values, writes quadrature data dim 0
//                    (G along x, B along y) and dim 1 (B along x, G along y).
//   transpose != 0 : reads quadrature data dim 0 (B^T along y, G^T along x)
//                    and dim 1 (G^T along y, B^T along x), sums the two
//                    results and writes nodal values.
// The call order below matters: every contraction contains block barriers.
//------------------------------------------------------------------------------
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V, CeedScalar *slice) {
  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int comp = tidz%BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const CeedInt firstElem = blockIdx.x*elemsPerBlock + tidz/BASIS_NCOMP;

  CeedScalar r_U;
  CeedScalar r_V;
  CeedScalar r_t;

  for (CeedInt elem = firstElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if (transpose) {
      // dim 0 contribution
      readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      // dim 1 contribution, accumulated into r_V
      readQuads2d(elem, tidx, tidy, comp, 1, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      r_V += r_U;
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    } else {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // dim 0: G along x, B along y
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      // dim 1: B along x, G along y
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 1, nelem, r_V, d_V);
    }
  }
}
383c532df63SYohann 
//------------------------------------------------------------------------------
// 2D quadrature weights
//
// Thread (threadIdx.x, threadIdx.y) writes the tensor-product weight
// qweight1d[x]*qweight1d[y] to position x + y*Q1D of each element's Q1D*Q1D
// block of w. Elements start at blockIdx.x*blockDim.z + threadIdx.z and
// advance by gridDim.x*blockDim.z.
// NOTE(review): assumes blockDim.x == blockDim.y == Q1D -- confirm at launch.
//------------------------------------------------------------------------------
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy];
  const int point = qx + qy*Q1D;
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z)
    w[elem*Q1D*Q1D + point] = wq;
}
398ab213215SJeremy L Thompson 
//------------------------------------------------------------------------------
// 3D
//------------------------------------------------------------------------------
402ab213215SJeremy L Thompson 
//------------------------------------------------------------------------------
// Read DoFs (3D)
//
// Thread (tidx, tidy) loads the column of P1D nodes along z at (tidx, tidy)
// of component comp of element elem from d_U (layout: x fastest, then y,
// then z, then element, then component) into r_U[0..P1D). Threads outside
// the P1D x P1D grid load zeros. Entries P1D..Q1D-1 are zero-filled.
//------------------------------------------------------------------------------
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *r_U) {
  if (tidx < P1D && tidy < P1D) {
    const int base = tidx + tidy*P1D + elem*P1D*P1D*P1D +
                     comp*P1D*P1D*P1D*nelem;
    for (int z = 0; z < P1D; ++z)
      r_U[z] = d_U[base + z*P1D*P1D];
  } else {
    for (int z = 0; z < P1D; ++z)
      r_U[z] = 0.0;
  }
  for (int z = P1D; z < Q1D; ++z)
    r_U[z] = 0.0;
}
417c532df63SYohann 
//------------------------------------------------------------------------------
// Write DoFs (3D)
//
// Thread (tidx, tidy) stores r_V[0..P1D) as the column of nodes along z at
// (tidx, tidy) of component comp of element elem in d_V (same layout as
// readDofs3d). Threads outside the P1D x P1D grid store nothing.
//------------------------------------------------------------------------------
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D)
    return;
  const int base = tidx + tidy*P1D + elem*P1D*P1D*P1D +
                   comp*P1D*P1D*P1D*nelem;
  for (int z = 0; z < P1D; ++z)
    d_V[base + z*P1D*P1D] = r_V[z];
}
43149fd234cSJeremy L Thompson 
//------------------------------------------------------------------------------
// Read quadrature point data (3D)
//
// Thread (tidx, tidy) loads the column of Q1D points along z at (tidx, tidy)
// of component comp, direction dim, element elem from d_U (layout: x fastest,
// then y, then z, then element, then component, then dim) into r_U[0..Q1D).
// Threads outside the Q1D x Q1D grid load zeros. Entries Q1D..P1D-1 are
// zero-filled.
//------------------------------------------------------------------------------
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *r_U) {
  if (tidx < Q1D && tidy < Q1D) {
    const int base = tidx + tidy*Q1D + elem*Q1D*Q1D*Q1D +
                     comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
    for (int z = 0; z < Q1D; ++z)
      r_U[z] = d_U[base + z*Q1D*Q1D];
  } else {
    for (int z = 0; z < Q1D; ++z)
      r_U[z] = 0.0;
  }
  for (int z = Q1D; z < P1D; ++z)
    r_U[z] = 0.0;
}
446c532df63SYohann 
//------------------------------------------------------------------------------
// Write quadrature point data (3D)
//
// Thread (tidx, tidy) stores r_V[0..Q1D) as the column of points along z at
// (tidx, tidy) of component comp, direction dim, element elem in d_V (same
// layout as readQuads3d). Threads outside the Q1D x Q1D grid store nothing.
//------------------------------------------------------------------------------
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar *r_V, CeedScalar *d_V) {
  if (tidx >= Q1D || tidy >= Q1D)
    return;
  const int base = tidx + tidy*Q1D + elem*Q1D*Q1D*Q1D +
                   comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
  for (int z = 0; z < Q1D; ++z)
    d_V[base + z*Q1D*Q1D] = r_V[z];
}
460c532df63SYohann 
//------------------------------------------------------------------------------
// 3D tensor contract x
//
// For each z-layer k < P1D: every thread publishes U[k] into cell
// (tidx, tidy) of plane tidz of slice. After a barrier, threads with
// tidx < Q1D and tidy < P1D compute
//   V[k] = sum_{i < P1D} B[tidx*P1D + i] * plane[tidy][i]
// and all others get V[k] = 0. A second barrier precedes the next layer.
// Both barriers must be reached by the whole block on every iteration.
//------------------------------------------------------------------------------
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  CeedScalar *plane = slice + tidz*T1D*T1D;
  const bool active = tidx < Q1D && tidy < P1D;
  for (int k = 0; k < P1D; ++k) {
    plane[tidx + tidy*T1D] = U[k];
    __syncthreads();
    CeedScalar acc = 0.0;
    if (active)
      for (int i = 0; i < P1D; ++i)
        acc += B[i + tidx*P1D] * plane[i + tidy*T1D]; // Contract x direction
    V[k] = acc;
    __syncthreads();
  }
}
479c532df63SYohann 
480ab213215SJeremy L Thompson //------------------------------------------------------------------------------
481ab213215SJeremy L Thompson // 3D tensor contract y
482ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Apply B along the y direction, one z level k < P1D at a time.
// Each thread publishes U[k] to the shared slice. Threads with tidx < Q1D
// and tidy < Q1D then combine row tidy of B (P1D entries) with column tidx of
// the plane (stride T1D) to form V[k]. All other threads get V[k] = 0.
// Contains __syncthreads(), so every thread of the block must call it.
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  CeedScalar *plane = slice + tidz*T1D*T1D;
  const bool active = tidx < Q1D && tidy < Q1D;
  for (int k = 0; k < P1D; ++k) {
    plane[tidx + tidy*T1D] = U[k];
    __syncthreads();
    CeedScalar sum = 0.0;
    if (active) {
      const CeedScalar *Brow = B + tidy*P1D;
      for (int i = 0; i < P1D; ++i)
        sum += Brow[i] * plane[tidx + i*T1D]; // Contract y direction
    }
    V[k] = sum;
    __syncthreads();
  }
}
498c532df63SYohann 
499ab213215SJeremy L Thompson //------------------------------------------------------------------------------
500ab213215SJeremy L Thompson // 3D tensor contract z
501ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Apply B along the z direction. The z data lives entirely in this thread
// (U[0..P1D)), so no shared memory or synchronization is needed; slice and
// tidz are unused. Threads with tidx < Q1D and tidy < Q1D compute
// V[q] = sum_p B[p + q*P1D] * U[p] for q < Q1D; other threads get zeros.
// Entries V[Q1D..P1D) (present when P1D > Q1D) are also cleared.
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  const bool active = tidx < Q1D && tidy < Q1D;
  for (int q = 0; q < Q1D; ++q) {
    CeedScalar sum = 0.0;
    if (active) {
      const CeedScalar *Brow = B + q*P1D;
      for (int p = 0; p < P1D; ++p)
        sum += Brow[p] * U[p]; // Contract z direction
    }
    V[q] = sum;
  }
  for (int q = Q1D; q < P1D; ++q)
    V[q] = 0.0;
}
516c532df63SYohann 
517ab213215SJeremy L Thompson //------------------------------------------------------------------------------
518ab213215SJeremy L Thompson // 3D transpose tensor contract z
519ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Apply the transpose of B along the z direction, entirely within this
// thread (slice and tidz are unused). Threads with tidx < Q1D and
// tidy < Q1D compute V[p] = sum_q B[p + q*P1D] * U[q] for p < P1D; other
// threads get zeros. Entries V[P1D..Q1D) (present when Q1D > P1D) are cleared.
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  const bool active = tidx < Q1D && tidy < Q1D;
  for (int p = 0; p < P1D; ++p) {
    CeedScalar sum = 0.0;
    if (active)
      for (int q = 0; q < Q1D; ++q)
        sum += B[p + q*P1D] * U[q]; // Contract z direction
    V[p] = sum;
  }
  for (int p = P1D; p < Q1D; ++p)
    V[p] = 0.0;
}
534c532df63SYohann 
535ab213215SJeremy L Thompson //------------------------------------------------------------------------------
536ab213215SJeremy L Thompson // 3D transpose tensor contract y
537ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Apply the transpose of B along the y direction, one z level k < P1D at a
// time. Each thread publishes U[k] to the shared slice. Threads with
// tidx < Q1D and tidy < P1D then combine column tidy of B (entries
// B[tidy + q*P1D]) with column tidx of the plane to form V[k]. All other
// threads get V[k] = 0.
// Contains __syncthreads(), so every thread of the block must call it.
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  CeedScalar *plane = slice + tidz*T1D*T1D;
  const bool active = tidx < Q1D && tidy < P1D;
  for (int k = 0; k < P1D; ++k) {
    plane[tidx + tidy*T1D] = U[k];
    __syncthreads();
    CeedScalar sum = 0.0;
    if (active)
      for (int q = 0; q < Q1D; ++q)
        sum += B[tidy + q*P1D] * plane[tidx + q*T1D]; // Contract y direction
    V[k] = sum;
    __syncthreads();
  }
}
553c532df63SYohann 
554ab213215SJeremy L Thompson //------------------------------------------------------------------------------
555ab213215SJeremy L Thompson // 3D transpose tensor contract x
556ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Apply the transpose of B along the x direction, one z level k < P1D at a
// time. Each thread publishes U[k] to the shared slice. Threads with
// tidx < P1D and tidy < P1D then combine column tidx of B (entries
// B[tidx + q*P1D]) with row tidy of the plane to form V[k]. All other
// threads get V[k] = 0.
// Contains __syncthreads(), so every thread of the block must call it.
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  CeedScalar *plane = slice + tidz*T1D*T1D;
  const bool active = tidx < P1D && tidy < P1D;
  for (int k = 0; k < P1D; ++k) {
    plane[tidx + tidy*T1D] = U[k];
    __syncthreads();
    CeedScalar sum = 0.0;
    if (active) {
      const CeedScalar *row = plane + tidy*T1D;
      for (int q = 0; q < Q1D; ++q)
        sum += B[tidx + q*P1D] * row[q]; // Contract x direction
    }
    V[k] = sum;
    __syncthreads();
  }
}
572c532df63SYohann 
573ab213215SJeremy L Thompson //------------------------------------------------------------------------------
574ab213215SJeremy L Thompson // 3D interpolate to quadrature points
575ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Interpolate between nodes and quadrature points of a 3D tensor basis.
// threadIdx.x and threadIdx.y index a plane of T1D x T1D points. threadIdx.z
// encodes (element within the block, component) as
// elemInBlock*BASIS_NCOMP + comp. Elements are visited starting at
// blockIdx.x*elemsPerBlock + elemInBlock, advancing by
// gridDim.x*elemsPerBlock.
// transpose == 0: read nodal data from d_U, contract in x, y, then z, and
//                 write quadrature data to d_V.
// transpose != 0: read quadrature data, apply the transposed contractions in
//                 z, y, then x, and write nodal data.
// slice is the shared scratch buffer used by the contractions (T1D*T1D
// entries per z thread).
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar bufA[T1D];
  CeedScalar bufB[T1D];

  const int tx = threadIdx.x;
  const int ty = threadIdx.y;
  const int tz = threadIdx.z;
  const int comp = tz % BASIS_NCOMP;
  const int elemInBlock = tz / BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z / BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + elemInBlock; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    for (int i = 0; i < T1D; ++i)
      bufA[i] = bufB[i] = 0.0;

    if (transpose) {
      readQuads3d(elem, tx, ty, comp, 0, nelem, d_U, bufA);
      ContractTransposeZ3d(slice, tx, ty, tz, bufA, c_B, bufB);
      ContractTransposeY3d(slice, tx, ty, tz, bufB, c_B, bufA);
      ContractTransposeX3d(slice, tx, ty, tz, bufA, c_B, bufB);
      writeDofs3d(elem, tx, ty, comp, nelem, bufB, d_V);
    } else {
      readDofs3d(elem, tx, ty, comp, nelem, d_U, bufA);
      ContractX3d(slice, tx, ty, tz, bufA, c_B, bufB);
      ContractY3d(slice, tx, ty, tz, bufB, c_B, bufA);
      ContractZ3d(slice, tx, ty, tz, bufA, c_B, bufB);
      writeQuads3d(elem, tx, ty, comp, 0, nelem, bufB, d_V);
    }
  }
}
612c532df63SYohann 
613ab213215SJeremy L Thompson //------------------------------------------------------------------------------
614ab213215SJeremy L Thompson // 3D derivatives at quadrature points
615ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Gradient of a 3D tensor basis. Thread layout and element loop are the same
// as interp3d.
// Forward (transpose == 0): for each direction d in 0, 1, 2, contract the
// nodal values with c_G along direction d and with c_B along the other two,
// then write the result as the quadrature data of component d.
// Transpose: for each d, read the quadrature data of component d, apply the
// transposed contractions (c_G along d, c_B elsewhere), sum the three
// contributions, and write the sum as nodal data.
// The contractions synchronize, so every thread of the block must call this.
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // All work arrays have T1D entries so that each can hold either nodal or
  // quadrature data along z (T1D presumably equals max(P1D, Q1D), like
  // thread1d on the host; confirm against the JIT defines)
  CeedScalar r_in[T1D];
  CeedScalar r_out[T1D];
  CeedScalar r_tmp[T1D];

  const int tx = threadIdx.x;
  const int ty = threadIdx.y;
  const int tz = threadIdx.z;
  const int comp = tz % BASIS_NCOMP;
  const int elemInBlock = tz / BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z / BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + elemInBlock; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    for (int i = 0; i < T1D; ++i)
      r_in[i] = r_out[i] = r_tmp[i] = 0.0;

    if (transpose) {
      // Direction 0: derivative along x
      readQuads3d(elem, tx, ty, comp, 0, nelem, d_U, r_in);
      ContractTransposeZ3d(slice, tx, ty, tz, r_in, c_B, r_tmp);
      ContractTransposeY3d(slice, tx, ty, tz, r_tmp, c_B, r_in);
      ContractTransposeX3d(slice, tx, ty, tz, r_in, c_G, r_out);
      // Direction 1: derivative along y
      readQuads3d(elem, tx, ty, comp, 1, nelem, d_U, r_in);
      ContractTransposeZ3d(slice, tx, ty, tz, r_in, c_B, r_tmp);
      ContractTransposeY3d(slice, tx, ty, tz, r_tmp, c_G, r_in);
      ContractTransposeX3d(slice, tx, ty, tz, r_in, c_B, r_tmp);
      add(r_out, r_tmp);
      // Direction 2: derivative along z
      readQuads3d(elem, tx, ty, comp, 2, nelem, d_U, r_in);
      ContractTransposeZ3d(slice, tx, ty, tz, r_in, c_G, r_tmp);
      ContractTransposeY3d(slice, tx, ty, tz, r_tmp, c_B, r_in);
      ContractTransposeX3d(slice, tx, ty, tz, r_in, c_B, r_tmp);
      add(r_out, r_tmp);
      writeDofs3d(elem, tx, ty, comp, nelem, r_out, d_V);
    } else {
      readDofs3d(elem, tx, ty, comp, nelem, d_U, r_in);
      for (int d = 0; d < 3; ++d) {
        ContractX3d(slice, tx, ty, tz, r_in, d == 0 ? c_G : c_B, r_out);
        ContractY3d(slice, tx, ty, tz, r_out, d == 1 ? c_G : c_B, r_tmp);
        ContractZ3d(slice, tx, ty, tz, r_tmp, d == 2 ? c_G : c_B, r_out);
        writeQuads3d(elem, tx, ty, comp, d, nelem, r_out, d_V);
      }
    }
  }
}
680c532df63SYohann 
681ab213215SJeremy L Thompson //------------------------------------------------------------------------------
682ab213215SJeremy L Thompson // 3D quadrature weights
683ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Fill w with the tensor-product quadrature weights
// qweight1d[x]*qweight1d[y]*qweight1d[z] for every element.
// Each thread owns one quadrature point (x, y, z) = threadIdx, so the block
// is expected to be Q1D x Q1D x Q1D threads, which is how the host launches
// it. Blocks take elements blockIdx.x, blockIdx.x + gridDim.x, and so on.
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const int qz = threadIdx.z;
  const int local = qx + Q1D*(qy + Q1D*qz);
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy]*qweight1d[qz];
  for (int e = blockIdx.x; e < nelem; e += gridDim.x)
    w[local + e*Q1D*Q1D*Q1D] = wq;
}
695ab213215SJeremy L Thompson 
696ab213215SJeremy L Thompson 
697ab213215SJeremy L Thompson //------------------------------------------------------------------------------
698ab213215SJeremy L Thompson // Basis kernels
699ab213215SJeremy L Thompson //------------------------------------------------------------------------------
700ab213215SJeremy L Thompson 
701ab213215SJeremy L Thompson //------------------------------------------------------------------------------
702ab213215SJeremy L Thompson // Interp kernel by dim
703ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Interpolation kernel entry point. It dispatches on the compile-time
// BASIS_DIM to the 1D, 2D or 3D implementation and hands it the dynamic
// shared-memory buffer whose byte size is supplied at launch.
extern "C" __global__ void interp(const CeedInt nelem, const int transpose,
                                  const CeedScalar *c_B,
                                  const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  // Use CeedScalar, not double: the helpers take CeedScalar pointers and the
  // host sizes this buffer in units of sizeof(CeedScalar)
  extern __shared__ CeedScalar slice[];
  if (BASIS_DIM == 1) {
    interp1d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM == 2) {
    interp2d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM == 3) {
    interp3d(nelem, transpose, c_B, d_U, d_V, slice);
  }
}
717c532df63SYohann 
718ab213215SJeremy L Thompson //------------------------------------------------------------------------------
719ab213215SJeremy L Thompson // Grad kernel by dim
720ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Gradient kernel entry point. It dispatches on the compile-time BASIS_DIM
// to the 1D, 2D or 3D implementation and hands it the dynamic shared-memory
// buffer whose byte size is supplied at launch.
extern "C" __global__ void grad(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *c_G,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V) {
  // Use CeedScalar, not double: the helpers take CeedScalar pointers and the
  // host sizes this buffer in units of sizeof(CeedScalar)
  extern __shared__ CeedScalar slice[];
  if (BASIS_DIM == 1) {
    grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM == 2) {
    grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM == 3) {
    grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  }
}
734c532df63SYohann 
735ab213215SJeremy L Thompson //------------------------------------------------------------------------------
736ab213215SJeremy L Thompson // Weight kernels by dim
737ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Quadrature-weight kernel entry point. It selects weight1d, weight2d or
// weight3d from the compile-time BASIS_DIM; any other value does nothing.
extern "C" __global__ void weight(const CeedInt nelem,
                                  const CeedScalar *__restrict__ qweight1d,
                                  CeedScalar *__restrict__ v) {
  switch (BASIS_DIM) {
  case 1:
    weight1d(nelem, qweight1d, v);
    break;
  case 2:
    weight2d(nelem, qweight1d, v);
    break;
  case 3:
    weight3d(nelem, qweight1d, v);
    break;
  }
}
749c532df63SYohann 
750c532df63SYohann );
751cb0b5415Sjeremylt // *INDENT-ON*
752c532df63SYohann 
753ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Device initialization
755ab213215SJeremy L Thompson //------------------------------------------------------------------------------
756c532df63SYohann int CeedCudaInitInterp(CeedScalar *d_B, CeedInt P1d, CeedInt Q1d,
757c532df63SYohann                        CeedScalar **c_B);
758c532df63SYohann int CeedCudaInitInterpGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P1d,
7597f823360Sjeremylt                            CeedInt Q1d, CeedScalar **c_B_ptr,
7607f823360Sjeremylt                            CeedScalar **c_G_ptr);
761c532df63SYohann 
762ab213215SJeremy L Thompson //------------------------------------------------------------------------------
763ab213215SJeremy L Thompson // Apply basis
764ab213215SJeremy L Thompson //------------------------------------------------------------------------------
765c532df63SYohann int CeedBasisApplyTensor_Cuda_shared(CeedBasis basis, const CeedInt nelem,
766c532df63SYohann                                      CeedTransposeMode tmode,
7677f823360Sjeremylt                                      CeedEvalMode emode, CeedVector u,
7687f823360Sjeremylt                                      CeedVector v) {
769c532df63SYohann   int ierr;
770c532df63SYohann   Ceed ceed;
771c532df63SYohann   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
772c532df63SYohann   Ceed_Cuda_shared *ceed_Cuda;
773777ff853SJeremy L Thompson   CeedGetData(ceed, &ceed_Cuda); CeedChk(ierr);
774c532df63SYohann   CeedBasis_Cuda_shared *data;
775777ff853SJeremy L Thompson   CeedBasisGetData(basis, &data); CeedChk(ierr);
776c532df63SYohann   const CeedInt transpose = tmode == CEED_TRANSPOSE;
7774247ecf3SYohann Dudouit   CeedInt dim, ncomp;
778074be161SYohann Dudouit   ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
7794247ecf3SYohann Dudouit   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
780c532df63SYohann 
781ab213215SJeremy L Thompson   // Read vectors
782c532df63SYohann   const CeedScalar *d_u;
783c532df63SYohann   CeedScalar *d_v;
784c532df63SYohann   if (emode != CEED_EVAL_WEIGHT) {
785c532df63SYohann     ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChk(ierr);
786c532df63SYohann   }
787c532df63SYohann   ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChk(ierr);
788c532df63SYohann 
789ab213215SJeremy L Thompson   // Clear v for transpose mode
790c532df63SYohann   if (tmode == CEED_TRANSPOSE) {
791c532df63SYohann     CeedInt length;
792c532df63SYohann     ierr = CeedVectorGetLength(v, &length); CeedChk(ierr);
793c532df63SYohann     ierr = cudaMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChk(ierr);
794c532df63SYohann   }
795ab213215SJeremy L Thompson 
796ab213215SJeremy L Thompson   // Apply basis operation
797ab213215SJeremy L Thompson   switch (emode) {
798ab213215SJeremy L Thompson   case CEED_EVAL_INTERP: {
799c532df63SYohann     CeedInt P1d, Q1d;
800c532df63SYohann     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
801c532df63SYohann     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
802*18d499f1SYohann     CeedInt thread1d = CeedIntMax(Q1d, P1d);
803c532df63SYohann     ierr = CeedCudaInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B);
804c532df63SYohann     CeedChk(ierr);
805cb0b5415Sjeremylt     void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
806ccf0fe6fSjeremylt                           &d_u, &d_v
807ccf0fe6fSjeremylt                          };
8084d537eeaSYohann     if (dim == 1) {
809d94769d2SYohann Dudouit       CeedInt elemsPerBlock = 32;
8104d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
8114d537eeaSYohann                                              ? 1 : 0 );
812*18d499f1SYohann       CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
813*18d499f1SYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, thread1d, 1,
8144d537eeaSYohann                                         elemsPerBlock, sharedMem,
815ab213215SJeremy L Thompson                                         interpargs); CeedChk(ierr);
816074be161SYohann Dudouit     } else if (dim == 2) {
8174247ecf3SYohann Dudouit       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
8180f70cdf6SJeremy L Thompson       // elemsPerBlock must be at least 1
819*18d499f1SYohann       CeedInt elemsPerBlock = CeedIntMax(thread1d<7?optElems[thread1d]/ncomp:1, 1);
8204d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
8214d537eeaSYohann                                              ? 1 : 0 );
822*18d499f1SYohann       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
823*18d499f1SYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, thread1d, thread1d,
8244d537eeaSYohann                                         ncomp*elemsPerBlock, sharedMem,
825ab213215SJeremy L Thompson                                         interpargs); CeedChk(ierr);
826074be161SYohann Dudouit     } else if (dim == 3) {
8273f63d318SYohann Dudouit       CeedInt elemsPerBlock = 1;
8284d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
8294d537eeaSYohann                                              ? 1 : 0 );
830*18d499f1SYohann       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
831*18d499f1SYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, thread1d, thread1d,
8324d537eeaSYohann                                         ncomp*elemsPerBlock, sharedMem,
833ab213215SJeremy L Thompson                                         interpargs); CeedChk(ierr);
834074be161SYohann Dudouit     }
835ab213215SJeremy L Thompson   } break;
836ab213215SJeremy L Thompson   case CEED_EVAL_GRAD: {
837c532df63SYohann     CeedInt P1d, Q1d;
838c532df63SYohann     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
839c532df63SYohann     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
840*18d499f1SYohann     CeedInt thread1d = CeedIntMax(Q1d, P1d);
841c532df63SYohann     ierr = CeedCudaInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d,
842c532df63SYohann                                   Q1d, &data->c_B, &data->c_G);
843c532df63SYohann     CeedChk(ierr);
844cb0b5415Sjeremylt     void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
845ccf0fe6fSjeremylt                         &data->c_G, &d_u, &d_v
846ccf0fe6fSjeremylt                        };
8474d537eeaSYohann     if (dim == 1) {
848d94769d2SYohann Dudouit       CeedInt elemsPerBlock = 32;
8494d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
8504d537eeaSYohann                                              ? 1 : 0 );
851*18d499f1SYohann       CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
852*18d499f1SYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, thread1d, 1,
853ab213215SJeremy L Thompson                                         elemsPerBlock, sharedMem, gradargs);
854c532df63SYohann       CeedChk(ierr);
855074be161SYohann Dudouit     } else if (dim == 2) {
8564247ecf3SYohann Dudouit       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
8570f70cdf6SJeremy L Thompson       // elemsPerBlock must be at least 1
858*18d499f1SYohann       CeedInt elemsPerBlock = CeedIntMax(thread1d<7?optElems[thread1d]/ncomp:1, 1);
8594d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
8604d537eeaSYohann                                              ? 1 : 0 );
861*18d499f1SYohann       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
862*18d499f1SYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, thread1d, thread1d,
8634d537eeaSYohann                                         ncomp*elemsPerBlock, sharedMem,
864ab213215SJeremy L Thompson                                         gradargs); CeedChk(ierr);
865074be161SYohann Dudouit     } else if (dim == 3) {
8663f63d318SYohann Dudouit       CeedInt elemsPerBlock = 1;
8674d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
8684d537eeaSYohann                                              ? 1 : 0 );
869*18d499f1SYohann       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
870*18d499f1SYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, thread1d, thread1d,
8714d537eeaSYohann                                         ncomp*elemsPerBlock, sharedMem,
872ab213215SJeremy L Thompson                                         gradargs); CeedChk(ierr);
873074be161SYohann Dudouit     }
874ab213215SJeremy L Thompson   } break;
875ab213215SJeremy L Thompson   case CEED_EVAL_WEIGHT: {
876074be161SYohann Dudouit     CeedInt Q1d;
877074be161SYohann Dudouit     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
878c532df63SYohann     void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
879074be161SYohann Dudouit     if (dim == 1) {
880074be161SYohann Dudouit       const CeedInt elemsPerBlock = 32/Q1d;
8814d537eeaSYohann       const CeedInt gridsize = nelem/elemsPerBlock + ( (
8824d537eeaSYohann                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
8837f823360Sjeremylt       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d,
8847f823360Sjeremylt                                   elemsPerBlock, 1, weightargs);
8851226057fSYohann Dudouit       CeedChk(ierr);
886074be161SYohann Dudouit     } else if (dim == 2) {
887717ff8a3SYohann Dudouit       const CeedInt optElems = 32/(Q1d*Q1d);
888717ff8a3SYohann Dudouit       const CeedInt elemsPerBlock = optElems>0?optElems:1;
8894d537eeaSYohann       const CeedInt gridsize = nelem/elemsPerBlock + ( (
8904d537eeaSYohann                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
8914d537eeaSYohann       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d,
8924d537eeaSYohann                                   elemsPerBlock, weightargs);
8931226057fSYohann Dudouit       CeedChk(ierr);
894074be161SYohann Dudouit     } else if (dim == 3) {
895074be161SYohann Dudouit       const CeedInt gridsize = nelem;
8964d537eeaSYohann       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d, Q1d,
8974d537eeaSYohann                                   weightargs);
8981226057fSYohann Dudouit       CeedChk(ierr);
899074be161SYohann Dudouit     }
900ab213215SJeremy L Thompson   } break;
901ab213215SJeremy L Thompson   // LCOV_EXCL_START
902ab213215SJeremy L Thompson   // Evaluate the divergence to/from the quadrature points
903ab213215SJeremy L Thompson   case CEED_EVAL_DIV:
904ab213215SJeremy L Thompson     return CeedError(ceed, 1, "CEED_EVAL_DIV not supported");
905ab213215SJeremy L Thompson   // Evaluate the curl to/from the quadrature points
906ab213215SJeremy L Thompson   case CEED_EVAL_CURL:
907ab213215SJeremy L Thompson     return CeedError(ceed, 1, "CEED_EVAL_CURL not supported");
908ab213215SJeremy L Thompson   // Take no action, BasisApply should not have been called
909ab213215SJeremy L Thompson   case CEED_EVAL_NONE:
910ab213215SJeremy L Thompson     return CeedError(ceed, 1,
911ab213215SJeremy L Thompson                      "CEED_EVAL_NONE does not make sense in this context");
912ab213215SJeremy L Thompson     // LCOV_EXCL_STOP
913c532df63SYohann   }
914c532df63SYohann 
915ab213215SJeremy L Thompson   // Restore vectors
916c532df63SYohann   if (emode != CEED_EVAL_WEIGHT) {
917c532df63SYohann     ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChk(ierr);
918c532df63SYohann   }
919c532df63SYohann   ierr = CeedVectorRestoreArray(v, &d_v); CeedChk(ierr);
920c532df63SYohann   return 0;
921c532df63SYohann }
922c532df63SYohann 
923ab213215SJeremy L Thompson //------------------------------------------------------------------------------
924ab213215SJeremy L Thompson // Destroy basis
925ab213215SJeremy L Thompson //------------------------------------------------------------------------------
926c532df63SYohann static int CeedBasisDestroy_Cuda_shared(CeedBasis basis) {
927c532df63SYohann   int ierr;
928c532df63SYohann   Ceed ceed;
929c532df63SYohann   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
930c532df63SYohann 
931c532df63SYohann   CeedBasis_Cuda_shared *data;
932777ff853SJeremy L Thompson   ierr = CeedBasisGetData(basis, &data); CeedChk(ierr);
933c532df63SYohann 
934c532df63SYohann   CeedChk_Cu(ceed, cuModuleUnload(data->module));
935c532df63SYohann 
936c532df63SYohann   ierr = cudaFree(data->d_qweight1d); CeedChk_Cu(ceed, ierr);
937c532df63SYohann   ierr = cudaFree(data->d_interp1d); CeedChk_Cu(ceed, ierr);
938c532df63SYohann   ierr = cudaFree(data->d_grad1d); CeedChk_Cu(ceed, ierr);
9391958eb7cSJeremy L Thompson   ierr = cudaFree(data->d_collograd1d); CeedChk_Cu(ceed, ierr);
940c532df63SYohann 
941c532df63SYohann   ierr = CeedFree(&data); CeedChk(ierr);
942c532df63SYohann 
943c532df63SYohann   return 0;
944c532df63SYohann }
945c532df63SYohann 
946ab213215SJeremy L Thompson //------------------------------------------------------------------------------
947ab213215SJeremy L Thompson // Create tensor basis
948ab213215SJeremy L Thompson //------------------------------------------------------------------------------
949c532df63SYohann int CeedBasisCreateTensorH1_Cuda_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
950c532df63SYohann                                         const CeedScalar *interp1d,
951c532df63SYohann                                         const CeedScalar *grad1d,
952c532df63SYohann                                         const CeedScalar *qref1d,
953c532df63SYohann                                         const CeedScalar *qweight1d,
954c532df63SYohann                                         CeedBasis basis) {
955c532df63SYohann   int ierr;
956c532df63SYohann   Ceed ceed;
957c532df63SYohann   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
958c532df63SYohann   CeedBasis_Cuda_shared *data;
959c532df63SYohann   ierr = CeedCalloc(1, &data); CeedChk(ierr);
960c532df63SYohann 
961ab213215SJeremy L Thompson   // Copy basis data to GPU
962c532df63SYohann   const CeedInt qBytes = Q1d * sizeof(CeedScalar);
963c532df63SYohann   ierr = cudaMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Cu(ceed, ierr);
964c532df63SYohann   ierr = cudaMemcpy(data->d_qweight1d, qweight1d, qBytes,
965c532df63SYohann                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
966c532df63SYohann 
967c532df63SYohann   const CeedInt iBytes = qBytes * P1d;
968c532df63SYohann   ierr = cudaMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Cu(ceed, ierr);
969c532df63SYohann   ierr = cudaMemcpy(data->d_interp1d, interp1d, iBytes,
970c532df63SYohann                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
971c532df63SYohann 
972c532df63SYohann   ierr = cudaMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Cu(ceed, ierr);
973c532df63SYohann   ierr = cudaMemcpy(data->d_grad1d, grad1d, iBytes,
974c532df63SYohann                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
975c532df63SYohann 
976ab213215SJeremy L Thompson   // Compute collocated gradient and copy to GPU
977ac421f39SYohann   data->d_collograd1d = NULL;
978ac421f39SYohann   if (dim == 3 && Q1d >= P1d) {
979ac421f39SYohann     CeedScalar *collograd1d;
980ac421f39SYohann     ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChk(ierr);
981ac421f39SYohann     ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChk(ierr);
982ac421f39SYohann     ierr = cudaMalloc((void **)&data->d_collograd1d, qBytes * Q1d);
983ac421f39SYohann     CeedChk_Cu(ceed, ierr);
984ac421f39SYohann     ierr = cudaMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d,
985ac421f39SYohann                       cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
9861958eb7cSJeremy L Thompson     ierr = CeedFree(&collograd1d); CeedChk(ierr);
987ac421f39SYohann   }
988ac421f39SYohann 
989ab213215SJeremy L Thompson   // Compile basis kernels
990c532df63SYohann   CeedInt ncomp;
991c532df63SYohann   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
992*18d499f1SYohann   ierr = CeedCompileCuda(ceed, kernelsShared, &data->module, 8,
993c532df63SYohann                          "Q1D", Q1d,
994c532df63SYohann                          "P1D", P1d,
995*18d499f1SYohann                          "T1D", CeedIntMax(Q1d, P1d),
996c532df63SYohann                          "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
997c532df63SYohann                              Q1d : P1d, dim),
998c532df63SYohann                          "BASIS_DIM", dim,
999c532df63SYohann                          "BASIS_NCOMP", ncomp,
1000c532df63SYohann                          "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
1001c532df63SYohann                          "BASIS_NQPT", CeedIntPow(Q1d, dim)
1002c532df63SYohann                         ); CeedChk(ierr);
10034a6d4bbdSYohann Dudouit   ierr = CeedGetKernelCuda(ceed, data->module, "interp", &data->interp);
1004c532df63SYohann   CeedChk(ierr);
10054a6d4bbdSYohann Dudouit   ierr = CeedGetKernelCuda(ceed, data->module, "grad", &data->grad);
1006c532df63SYohann   CeedChk(ierr);
10074a6d4bbdSYohann Dudouit   ierr = CeedGetKernelCuda(ceed, data->module, "weight", &data->weight);
1008c532df63SYohann   CeedChk(ierr);
1009c532df63SYohann 
1010777ff853SJeremy L Thompson   ierr = CeedBasisSetData(basis, data); CeedChk(ierr);
1011ab213215SJeremy L Thompson 
1012ab213215SJeremy L Thompson   // Register backend functions
1013c532df63SYohann   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
1014c532df63SYohann                                 CeedBasisApplyTensor_Cuda_shared);
1015c532df63SYohann   CeedChk(ierr);
1016c532df63SYohann   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
1017ab213215SJeremy L Thompson                                 CeedBasisDestroy_Cuda_shared); CeedChk(ierr);
1018c532df63SYohann   return 0;
1019c532df63SYohann }
1020ab213215SJeremy L Thompson //------------------------------------------------------------------------------
1021