xref: /libCEED/rust/libceed-sys/c-src/backends/cuda-shared/ceed-cuda-shared-basis.c (revision e15f9bd09af0280c89b79924fa9af7dd2e3e30be)
1c532df63SYohann // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2c532df63SYohann // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3c532df63SYohann // All Rights reserved. See files LICENSE and NOTICE for details.
4c532df63SYohann //
5c532df63SYohann // This file is part of CEED, a collection of benchmarks, miniapps, software
6c532df63SYohann // libraries and APIs for efficient high-order finite element and spectral
7c532df63SYohann // element discretizations for exascale applications. For more information and
8c532df63SYohann // source code availability see http://github.com/ceed.
9c532df63SYohann //
10c532df63SYohann // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11c532df63SYohann // a collaborative effort of two U.S. Department of Energy organizations (Office
12c532df63SYohann // of Science and the National Nuclear Security Administration) responsible for
13c532df63SYohann // the planning and preparation of a capable exascale ecosystem, including
14c532df63SYohann // software, applications, hardware, advanced system engineering and early
15c532df63SYohann // testbed platforms, in support of the nation's exascale computing imperative.
16c532df63SYohann 
173d576824SJeremy L Thompson #include <ceed.h>
183d576824SJeremy L Thompson #include <ceed-backend.h>
193d576824SJeremy L Thompson #include <cuda.h>
203d576824SJeremy L Thompson #include <cuda_runtime.h>
213d576824SJeremy L Thompson #include <stddef.h>
22c532df63SYohann #include "ceed-cuda-shared.h"
233d576824SJeremy L Thompson #include "../cuda/ceed-cuda.h"
24c532df63SYohann 
25ab213215SJeremy L Thompson //------------------------------------------------------------------------------
26ab213215SJeremy L Thompson // Shared mem kernels
27ab213215SJeremy L Thompson //------------------------------------------------------------------------------
28cb0b5415Sjeremylt // *INDENT-OFF*
29c532df63SYohann static const char *kernelsShared = QUOTE(
30c532df63SYohann 
31ab213215SJeremy L Thompson //------------------------------------------------------------------------------
32ab213215SJeremy L Thompson // Sum input into output
33ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Accumulate r_U into r_V in place over the first P1D entries (r_V += r_U).
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  int k = 0;
  while (k < P1D) {
    r_V[k] += r_U[k];
    ++k;
  }
}
38c532df63SYohann 
39ab213215SJeremy L Thompson //------------------------------------------------------------------------------
40ab213215SJeremy L Thompson // 1D
41ab213215SJeremy L Thompson //------------------------------------------------------------------------------
42c532df63SYohann 
43ab213215SJeremy L Thompson //------------------------------------------------------------------------------
44ab213215SJeremy L Thompson // Read DoFs
45ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Stage the P1D nodal values of component comp of element elem into this
// z-row of the shared slice (offset tidz*T1D). Entries P1D..Q1D-1 are
// zero-padded. Every thread of the row writes the whole row with the same
// values; tidx and tidy are not used.
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int tidz, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *slice) {
  CeedScalar *row = slice + tidz*T1D;
  const int base = elem*P1D + comp*P1D*nelem;
  for (int p = 0; p < P1D; ++p)
    row[p] = d_U[base + p];
  for (int p = P1D; p < Q1D; ++p)
    row[p] = 0.0;
}
55c532df63SYohann 
56ab213215SJeremy L Thompson //------------------------------------------------------------------------------
57ab213215SJeremy L Thompson // Write DoFs
58ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Store r_V as node tidx of component comp of element elem. Threads with
// tidx >= P1D have no node to write and return immediately.
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D)
    return;
  const int base = elem*P1D + comp*P1D*nelem;
  d_V[base + tidx] = r_V;
}
66c532df63SYohann 
67ab213215SJeremy L Thompson //------------------------------------------------------------------------------
68ab213215SJeremy L Thompson // Read quadrature point data
69ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Stage the Q1D quadrature values of component comp of element elem, for
// derivative direction dim, into this z-row of the shared slice (offset
// tidz*T1D). Entries Q1D..P1D-1 are zero-padded. Every thread of the row
// writes the whole row; tidx and tidy are not used.
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int tidz, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *slice) {
  CeedScalar *row = slice + tidz*T1D;
  const int base = elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  for (int q = 0; q < Q1D; ++q)
    row[q] = d_U[base + q];
  for (int q = Q1D; q < P1D; ++q)
    row[q] = 0.0;
}
80c532df63SYohann 
81ab213215SJeremy L Thompson //------------------------------------------------------------------------------
82ab213215SJeremy L Thompson // Write quadrature point data
83ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Store r_V as quadrature point tidx of component comp of element elem for
// direction dim. Threads with tidx >= Q1D return without writing.
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx >= Q1D)
    return;
  const int base = elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  d_V[base + tidx] = r_V;
}
91c532df63SYohann 
92ab213215SJeremy L Thompson //------------------------------------------------------------------------------
93ab213215SJeremy L Thompson // 1D tensor contraction
94ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// 1D contraction: thread tidx computes
//   V = sum_{i < P1D} B[i + tidx*P1D] * slice[i + tidz*T1D],
// where B holds Q1D rows of P1D entries and the slice row (offset tidz*T1D)
// was staged by readDofs1d. U is unused; the input comes from the slice.
// Only threads with tidx < Q1D own a row of B. The others return V = 0
// instead of reading past the end of B, which happens when the x-thread
// extent exceeds Q1D. Their value is discarded by writeQuads1d in any case.
// This matches the tidx < Q1D guard in ContractX2d.
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  V = 0.0;
  if (tidx < Q1D)
    for (int i = 0; i < P1D; ++i)
      V += B[i + tidx*P1D] * slice[i + tidz*T1D]; // Contract x direction
}
103c532df63SYohann 
104ab213215SJeremy L Thompson //------------------------------------------------------------------------------
105ab213215SJeremy L Thompson // 1D transpose tensor contraction
106ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// 1D transpose contraction: thread tidx computes
//   V = sum_{i < Q1D} B[tidx + i*P1D] * slice[i + tidz*T1D],
// applying the transpose of the Q1D x P1D matrix B to the staged row.
// U is unused.
// Only threads with tidx < P1D own a column of B. The others return V = 0
// instead of indexing past the end of B on the last rows. Their value is
// discarded by writeDofs1d in any case. This matches the tidx < P1D guard
// in ContractTransposeX2d.
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  V = 0.0;
  if (tidx < P1D)
    for (int i = 0; i < Q1D; ++i)
      V += B[tidx + i*P1D] * slice[i + tidz*T1D]; // Contract x direction
}
114c532df63SYohann 
115ab213215SJeremy L Thompson //------------------------------------------------------------------------------
116ab213215SJeremy L Thompson // 1D interpolate to quadrature points
117ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// 1D interpolation for every component of the elements handled by this
// z-row of threads.
// Forward (transpose == 0): nodes -> quadrature points with c_B.
// Transpose: quadrature points -> nodes with c_B^T.
// Elements start at blockIdx.x*blockDim.z + threadIdx.z and advance by
// gridDim.x*blockDim.z. Each z-row uses its own T1D-wide row of the slice.
// NOTE(review): no barrier separates consecutive components, so a thread that
// moves on to the next comp may overwrite the slice row while another thread
// of the same row is still contracting the previous one. This is only safe if
// each row stays within a converged warp -- confirm the launch layout.
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const CeedInt elemStride = gridDim.x*blockDim.z;

  CeedScalar r_V;
  CeedScalar r_t;  // placeholder for the contractions' unused U argument

  for (CeedInt elem = blockIdx.x*blockDim.z + tidz; elem < nelem;
       elem += elemStride) {
    for (int comp = 0; comp < BASIS_NCOMP; ++comp) {
      if (transpose) {
        readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      } else {
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      }
    }
  }
}
146c532df63SYohann 
147ab213215SJeremy L Thompson //------------------------------------------------------------------------------
148ab213215SJeremy L Thompson // 1D derivatives at quadrature points
149ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// 1D derivative at quadrature points (forward) or its transpose, for every
// component of the elements handled by this z-row of threads.
// Only c_G is applied; c_B is accepted for interface symmetry and not read.
// With a single direction, the quadrature data always uses dim index 0.
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int dim = 0;

  CeedScalar r_U;  // placeholder for the contractions' unused U argument
  CeedScalar r_V;

  for (CeedInt elem = blockIdx.x*blockDim.z + tidz; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for (int comp = 0; comp < BASIS_NCOMP; ++comp) {
      if (transpose) {
        readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      } else {
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      }
    }
  }
}
180c532df63SYohann 
181ab213215SJeremy L Thompson //------------------------------------------------------------------------------
182ab213215SJeremy L Thompson // 1D Quadrature weights
183ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Write the 1D quadrature weight of point threadIdx.x into w for each element
// handled by this thread. Elements start at blockIdx.x*blockDim.y +
// threadIdx.y and advance by gridDim.x*blockDim.y.
// Presumably launched with blockDim.x == Q1D -- confirm at the call site.
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int q = threadIdx.x;
  const CeedScalar wq = qweight1d[q];
  const CeedInt stride = gridDim.x*blockDim.y;
  for (CeedInt e = blockIdx.x*blockDim.y + threadIdx.y; e < nelem; e += stride) {
    const int ind = e*Q1D + q;
    w[ind] = wq;
  }
}
194ab213215SJeremy L Thompson 
195ab213215SJeremy L Thompson //------------------------------------------------------------------------------
196ab213215SJeremy L Thompson // 2D
197ab213215SJeremy L Thompson //------------------------------------------------------------------------------
198ab213215SJeremy L Thompson 
199ab213215SJeremy L Thompson //------------------------------------------------------------------------------
200ab213215SJeremy L Thompson // Read DoFs
201ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Load node (tidx, tidy) of component comp of element elem into U.
// Threads outside the P1D x P1D node grid receive 0.
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar &U) {
  if (tidx < P1D && tidy < P1D) {
    const int node = tidx + tidy*P1D;
    U = d_U[node + elem*P1D*P1D + comp*P1D*P1D*nelem];
  } else {
    U = 0.0;
  }
}
209c532df63SYohann 
210ab213215SJeremy L Thompson //------------------------------------------------------------------------------
211ab213215SJeremy L Thompson // Write DoFs
212ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Store r_V as node (tidx, tidy) of component comp of element elem.
// Threads outside the P1D x P1D node grid write nothing.
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  const bool inside = (tidx < P1D) && (tidy < P1D);
  if (!inside)
    return;
  const int node = tidx + tidy*P1D;
  d_V[node + elem*P1D*P1D + comp*P1D*P1D*nelem] = r_V;
}
220c532df63SYohann 
221ab213215SJeremy L Thompson //------------------------------------------------------------------------------
222ab213215SJeremy L Thompson // Read quadrature point data
223ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Load quadrature point (tidx, tidy) of component comp of element elem for
// direction dim into U. Threads outside the Q1D x Q1D grid receive 0.
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar &U ) {
  U = 0.0;
  if (tidx < Q1D && tidy < Q1D) {
    const int qpt = tidx + tidy*Q1D;
    const int offset = elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                       dim*BASIS_NCOMP*nelem*Q1D*Q1D;
    U = d_U[qpt + offset];
  }
}
232c532df63SYohann 
233ab213215SJeremy L Thompson //------------------------------------------------------------------------------
234ab213215SJeremy L Thompson // Write quadrature point data
235ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Store r_V as quadrature point (tidx, tidy) of component comp of element
// elem for direction dim. Threads outside the Q1D x Q1D grid write nothing.
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx >= Q1D || tidy >= Q1D)
    return;
  const int qpt = tidx + tidy*Q1D;
  const int offset = elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                     dim*BASIS_NCOMP*nelem*Q1D*Q1D;
  d_V[qpt + offset] = r_V;
}
244c532df63SYohann 
245ab213215SJeremy L Thompson //------------------------------------------------------------------------------
246ab213215SJeremy L Thompson // 2D tensor contraction x
247ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Contract along x within z-plane tidz of the slice (T1D x T1D per plane).
// Every thread first stores U at (tidx, tidy). After a barrier, threads with
// tidx < Q1D form V = sum_{i < P1D} B[i + tidx*P1D] * plane(i, tidy); the
// others get 0. A trailing barrier lets the next call reuse the plane.
// Contains __syncthreads(), so every thread of the block must reach it.
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  CeedScalar *plane = slice + tidz*T1D*T1D;
  plane[tidx + tidy*T1D] = U;
  __syncthreads();
  CeedScalar acc = 0.0;
  if (tidx < Q1D) {
    const CeedScalar *Brow = B + tidx*P1D;
    const CeedScalar *srow = plane + tidy*T1D;
    for (int i = 0; i < P1D; ++i)
      acc += Brow[i] * srow[i]; // Contract x direction
  }
  V = acc;
  __syncthreads();
}
260c532df63SYohann 
261ab213215SJeremy L Thompson //------------------------------------------------------------------------------
262ab213215SJeremy L Thompson // 2D tensor contraction y
263ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Contract along y within z-plane tidz of the slice.
// Every thread stores U at (tidx, tidy). After a barrier, threads with
// tidy < Q1D form V = sum_{i < P1D} B[i + tidy*P1D] * plane(tidx, i); the
// others get 0. A trailing barrier protects the plane for the next call.
// Contains __syncthreads(), so every thread of the block must reach it.
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  CeedScalar *plane = slice + tidz*T1D*T1D;
  plane[tidx + tidy*T1D] = U;
  __syncthreads();
  CeedScalar acc = 0.0;
  if (tidy < Q1D) {
    const CeedScalar *Brow = B + tidy*P1D;
    for (int i = 0; i < P1D; ++i)
      acc += Brow[i] * plane[tidx + i*T1D]; // Contract y direction
  }
  V = acc;
  __syncthreads();
}
276c532df63SYohann 
277ab213215SJeremy L Thompson //------------------------------------------------------------------------------
278ab213215SJeremy L Thompson // 2D transpose tensor contraction y
279ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Transpose contraction along y within z-plane tidz of the slice.
// Every thread stores U at (tidx, tidy). After a barrier, threads with
// tidy < P1D form V = sum_{i < Q1D} B[tidy + i*P1D] * plane(tidx, i); the
// others get 0. A trailing barrier protects the plane for the next call.
// Contains __syncthreads(), so every thread of the block must reach it.
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  CeedScalar *plane = slice + tidz*T1D*T1D;
  plane[tidx + tidy*T1D] = U;
  __syncthreads();
  CeedScalar acc = 0.0;
  if (tidy < P1D) {
    for (int i = 0; i < Q1D; ++i)
      acc += B[tidy + i*P1D] * plane[tidx + i*T1D]; // Contract y direction
  }
  V = acc;
  __syncthreads();
}
291c532df63SYohann 
292ab213215SJeremy L Thompson //------------------------------------------------------------------------------
293ab213215SJeremy L Thompson // 2D transpose tensor contraction x
294ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Transpose contraction along x within z-plane tidz of the slice.
// Every thread stores U at (tidx, tidy). After a barrier, threads with
// tidx < P1D form V = sum_{i < Q1D} B[tidx + i*P1D] * plane(i, tidy); the
// others get 0. A trailing barrier protects the plane for the next call.
// Contains __syncthreads(), so every thread of the block must reach it.
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  CeedScalar *plane = slice + tidz*T1D*T1D;
  plane[tidx + tidy*T1D] = U;
  __syncthreads();
  CeedScalar acc = 0.0;
  if (tidx < P1D) {
    const CeedScalar *srow = plane + tidy*T1D;
    for (int i = 0; i < Q1D; ++i)
      acc += B[tidx + i*P1D] * srow[i]; // Contract x direction
  }
  V = acc;
  __syncthreads();
}
306c532df63SYohann 
307ab213215SJeremy L Thompson //------------------------------------------------------------------------------
308ab213215SJeremy L Thompson // 2D interpolate to quadrature points
309ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// 2D interpolation.
// Thread layout: (x, y) is a node or quadrature point of an element.
// z packs the element slot and the component as
// tidz = blockElem*BASIS_NCOMP + comp, so each z-plane handles one component
// of one element and uses slice plane tidz.
// Forward (transpose == 0): nodes -> quadrature points, applying c_B along x
// then along y.
// Transpose: quadrature points -> nodes, applying c_B^T along y then along x.
// The contractions contain __syncthreads().
// NOTE(review): threads whose first elem is >= nelem skip the loop and
// therefore the barriers -- confirm the host launch never leaves part of a
// block idle.
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  // Both are always overwritten (by the read helpers and the contractions)
  // before being read, so one initialization here is enough.
  CeedScalar r_V = 0.0;
  CeedScalar r_t = 0.0;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  // Component handled by this z-plane. It is loop-invariant, so it is
  // computed once here rather than redeclared (and shadowed) inside the loop.
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if (!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
    } else {
      readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
343c532df63SYohann 
344ab213215SJeremy L Thompson //------------------------------------------------------------------------------
345ab213215SJeremy L Thompson // 2D derivatives at quadrature points
346ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// 2D derivatives at quadrature points (forward) or their transpose.
// Thread layout is the same as interp2d: (x, y) is the point in the element,
// and z = element slot * BASIS_NCOMP + component.
// Forward: dim 0 applies c_G along x and c_B along y; dim 1 applies c_B along
// x and c_G along y. Each result is written to its own dim slot.
// Transpose: applies the transpose of both directions and sums the two
// contributions into a single nodal value.
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V, CeedScalar *slice) {
  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int comp = tidz%BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const CeedInt firstElem = blockIdx.x*elemsPerBlock + tidz/BASIS_NCOMP;

  CeedScalar r_U;
  CeedScalar r_V;
  CeedScalar r_t;

  for (CeedInt elem = firstElem; elem < nelem; elem += gridDim.x*elemsPerBlock) {
    if (transpose) {
      // dim 0: c_B^T along y, then c_G^T along x
      readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      // dim 1: c_G^T along y, then c_B^T along x; accumulate into r_V
      readQuads2d(elem, tidx, tidy, comp, 1, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      r_V += r_U;
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    } else {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // dim 0: c_G along x, then c_B along y
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      // dim 1: c_B along x, then c_G along y
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 1, nelem, r_V, d_V);
    }
  }
}
389c532df63SYohann 
390ab213215SJeremy L Thompson //------------------------------------------------------------------------------
391ab213215SJeremy L Thompson // 2D quadrature weights
392ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Write the tensor-product weight qweight1d[x]*qweight1d[y] of point
// (threadIdx.x, threadIdx.y) into w for each element handled by this thread.
// Elements start at blockIdx.x*blockDim.z + threadIdx.z and advance by
// gridDim.x*blockDim.z.
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const int local = qx + qy*Q1D;
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy];
  for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < nelem;
       e += gridDim.x*blockDim.z) {
    const int ind = e*Q1D*Q1D + local;
    w[ind] = wq;
  }
}
404ab213215SJeremy L Thompson 
405ab213215SJeremy L Thompson //------------------------------------------------------------------------------
406ab213215SJeremy L Thompson // 3D
407ab213215SJeremy L Thompson //------------------------------------------------------------------------------
408ab213215SJeremy L Thompson 
409ab213215SJeremy L Thompson //------------------------------------------------------------------------------
410ab213215SJeremy L Thompson // Read DoFs
411ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Load the column of nodes (tidx, tidy, k), k < P1D, of component comp of
// element elem into r_U[k]. Threads outside the P1D x P1D face load zeros.
// Entries P1D..Q1D-1 are zero-padded, so r_U must hold max(P1D, Q1D) values.
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *r_U) {
  const bool inFace = (tidx < P1D) && (tidy < P1D);
  const int base = tidx + tidy*P1D + elem*P1D*P1D*P1D +
                   comp*P1D*P1D*P1D*nelem;
  for (int k = 0; k < P1D; ++k)
    r_U[k] = inFace ? d_U[base + k*P1D*P1D] : 0.0;
  for (int k = P1D; k < Q1D; ++k)
    r_U[k] = 0.0;
}
423c532df63SYohann 
424ab213215SJeremy L Thompson //------------------------------------------------------------------------------
42549fd234cSJeremy L Thompson // Write DoFs
42649fd234cSJeremy L Thompson //------------------------------------------------------------------------------
// Store r_V[k] as node (tidx, tidy, k), k < P1D, of component comp of
// element elem. Threads outside the P1D x P1D face write nothing.
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D)
    return;
  const int base = tidx + tidy*P1D + elem*P1D*P1D*P1D +
                   comp*P1D*P1D*P1D*nelem;
  for (int k = 0; k < P1D; ++k)
    d_V[base + k*P1D*P1D] = r_V[k];
}
43749fd234cSJeremy L Thompson 
43849fd234cSJeremy L Thompson //------------------------------------------------------------------------------
439ab213215SJeremy L Thompson // Read quadrature point data
440ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Load the column of quadrature points (tidx, tidy, k), k < Q1D, of
// component comp of element elem for direction dim into r_U[k]. Threads
// outside the Q1D x Q1D face load zeros. Entries Q1D..P1D-1 are zero-padded.
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *r_U) {
  const bool inFace = (tidx < Q1D) && (tidy < Q1D);
  const int base = tidx + tidy*Q1D + elem*Q1D*Q1D*Q1D +
                   comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
  for (int k = 0; k < Q1D; ++k)
    r_U[k] = inFace ? d_U[base + k*Q1D*Q1D] : 0.0;
  for (int k = Q1D; k < P1D; ++k)
    r_U[k] = 0.0;
}
452c532df63SYohann 
453ab213215SJeremy L Thompson //------------------------------------------------------------------------------
454ab213215SJeremy L Thompson // Write quadrature point data
455ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Store r_V[k] as quadrature point (tidx, tidy, k), k < Q1D, of component
// comp of element elem for direction dim. Threads outside the Q1D x Q1D face
// write nothing.
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar *r_V, CeedScalar *d_V) {
  if (!(tidx < Q1D && tidy < Q1D))
    return;
  const int base = tidx + tidy*Q1D + elem*Q1D*Q1D*Q1D +
                   comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
  for (int k = 0; k < Q1D; ++k)
    d_V[base + k*Q1D*Q1D] = r_V[k];
}
466c532df63SYohann 
467ab213215SJeremy L Thompson //------------------------------------------------------------------------------
468ab213215SJeremy L Thompson // 3D tensor contract x
469ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Contract along x, one layer k < P1D at a time, using z-plane tidz of the
// slice. For each k:
//   - every thread stores U[k] at (tidx, tidy);
//   - after a barrier, threads with tidx < Q1D and tidy < P1D form
//     V[k] = sum_{i < P1D} B[i + tidx*P1D] * plane(i, tidy); the others get 0;
//   - a trailing barrier protects the plane before the next layer.
// Contains __syncthreads() on every iteration, so every thread of the block
// must reach it.
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  CeedScalar *plane = slice + tidz*T1D*T1D;
  const bool active = (tidx < Q1D) && (tidy < P1D);
  for (int k = 0; k < P1D; ++k) {
    plane[tidx + tidy*T1D] = U[k];
    __syncthreads();
    CeedScalar acc = 0.0;
    if (active) {
      for (int i = 0; i < P1D; ++i)
        acc += B[i + tidx*P1D] * plane[i + tidy*T1D]; // Contract x direction
    }
    V[k] = acc;
    __syncthreads();
  }
}
485c532df63SYohann 
486ab213215SJeremy L Thompson //------------------------------------------------------------------------------
487ab213215SJeremy L Thompson // 3D tensor contract y
488ab213215SJeremy L Thompson //------------------------------------------------------------------------------
//
// Apply B along y to each of the P1D z-levels held in registers U.
// For every level k, U[k] is staged in this tidz layer's shared tile.
// Active threads (tidx < Q1D, tidy < Q1D) then compute
//   V[k] = sum_i B[i + tidy*P1D] * tile[tidx + i*T1D]
// and inactive threads get V[k] = 0.
// Every thread reaches both barriers on every iteration.
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  CeedScalar *tile = slice + tidz*T1D*T1D;
  const bool active = (tidx < Q1D) && (tidy < Q1D);
  for (int k = 0; k < P1D; ++k) {
    tile[tidx + tidy*T1D] = U[k];
    __syncthreads();
    CeedScalar acc = 0.0;
    if (active) {
      const CeedScalar *Brow = B + tidy*P1D;
      for (int i = 0; i < P1D; ++i)
        acc += Brow[i] * tile[tidx + i*T1D]; // Contract y direction
    }
    V[k] = acc;
    // Keep the tile intact until every thread has finished reading it
    __syncthreads();
  }
}
504c532df63SYohann 
505ab213215SJeremy L Thompson //------------------------------------------------------------------------------
506ab213215SJeremy L Thompson // 3D tensor contract z
507ab213215SJeremy L Thompson //------------------------------------------------------------------------------
//
// Apply B along z entirely in registers. Each thread already holds its full
// z-column in U, so no shared memory or barriers are needed; slice and tidz
// are not used here (presumably kept so all Contract*3d helpers share one
// signature).
// Active threads (tidx < Q1D, tidy < Q1D) compute
//   V[q] = sum_p B[p + q*P1D] * U[p]   for q < Q1D.
// Entries Q1D..P1D-1 (only present when P1D > Q1D) are zeroed.
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  const bool active = (tidx < Q1D) && (tidy < Q1D);
  for (int q = 0; q < Q1D; ++q) {
    CeedScalar acc = 0.0;
    if (active) {
      const CeedScalar *Brow = B + q*P1D;
      for (int p = 0; p < P1D; ++p)
        acc += Brow[p] * U[p]; // Contract z direction
    }
    V[q] = acc;
  }
  for (int q = Q1D; q < P1D; ++q)
    V[q] = 0.0;
}
522c532df63SYohann 
523ab213215SJeremy L Thompson //------------------------------------------------------------------------------
524ab213215SJeremy L Thompson // 3D transpose tensor contract z
525ab213215SJeremy L Thompson //------------------------------------------------------------------------------
//
// Apply B^T along z entirely in registers; slice and tidz are not used.
// Active threads (tidx < Q1D, tidy < Q1D) compute
//   V[p] = sum_q B[p + q*P1D] * U[q]   for p < P1D.
// Entries P1D..Q1D-1 (only present when Q1D > P1D) are zeroed.
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  const bool active = (tidx < Q1D) && (tidy < Q1D);
  for (int p = 0; p < P1D; ++p) {
    CeedScalar acc = 0.0;
    if (active)
      for (int q = 0; q < Q1D; ++q)
        acc += B[p + q*P1D] * U[q]; // Contract z direction
    V[p] = acc;
  }
  for (int p = P1D; p < Q1D; ++p)
    V[p] = 0.0;
}
540c532df63SYohann 
541ab213215SJeremy L Thompson //------------------------------------------------------------------------------
542ab213215SJeremy L Thompson // 3D transpose tensor contract y
543ab213215SJeremy L Thompson //------------------------------------------------------------------------------
//
// Apply B^T along y to each of the P1D z-levels held in registers U.
// For every level k, U[k] is staged in this tidz layer's shared tile.
// Active threads (tidx < Q1D, tidy < P1D) then compute
//   V[k] = sum_i B[tidy + i*P1D] * tile[tidx + i*T1D]
// and inactive threads get V[k] = 0.
// Every thread reaches both barriers on every iteration.
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  CeedScalar *tile = slice + tidz*T1D*T1D;
  const bool active = (tidx < Q1D) && (tidy < P1D);
  for (int k = 0; k < P1D; ++k) {
    tile[tidx + tidy*T1D] = U[k];
    __syncthreads();
    CeedScalar acc = 0.0;
    if (active) {
      for (int i = 0; i < Q1D; ++i)
        acc += B[tidy + i*P1D] * tile[tidx + i*T1D]; // Contract y direction
    }
    V[k] = acc;
    // Keep the tile intact until every thread has finished reading it
    __syncthreads();
  }
}
559c532df63SYohann 
560ab213215SJeremy L Thompson //------------------------------------------------------------------------------
561ab213215SJeremy L Thompson // 3D transpose tensor contract x
562ab213215SJeremy L Thompson //------------------------------------------------------------------------------
//
// Apply B^T along x to each of the P1D z-levels held in registers U.
// For every level k, U[k] is staged in this tidz layer's shared tile.
// Active threads (tidx < P1D, tidy < P1D) then compute
//   V[k] = sum_i B[tidx + i*P1D] * tile[i + tidy*T1D]
// and inactive threads get V[k] = 0.
// Every thread reaches both barriers on every iteration.
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  CeedScalar *tile = slice + tidz*T1D*T1D;
  const bool active = (tidx < P1D) && (tidy < P1D);
  for (int k = 0; k < P1D; ++k) {
    tile[tidx + tidy*T1D] = U[k];
    __syncthreads();
    CeedScalar acc = 0.0;
    if (active) {
      const CeedScalar *row = tile + tidy*T1D;
      for (int i = 0; i < Q1D; ++i)
        acc += B[tidx + i*P1D] * row[i]; // Contract x direction
    }
    V[k] = acc;
    // Keep the tile intact until every thread has finished reading it
    __syncthreads();
  }
}
578c532df63SYohann 
579ab213215SJeremy L Thompson //------------------------------------------------------------------------------
580ab213215SJeremy L Thompson // 3D interpolate to quadrature points
581ab213215SJeremy L Thompson //------------------------------------------------------------------------------
//
// 3D tensor-product interpolation.
//   transpose == 0 : dofs (d_U) -> quadrature values (d_V), applying B along
//                    x, then y, then z.
//   transpose != 0 : quadrature values (d_U) -> dofs (d_V), applying B^T along
//                    z, then y, then x.
// threadIdx.z packs (element within the block, component); threadIdx.x/y
// address the in-plane tensor index. Blocks stride over elements.
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  // Two register buffers that the contractions ping-pong between
  CeedScalar r_a[T1D];
  CeedScalar r_b[T1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int comp = tidz%BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const CeedInt firstElem = blockIdx.x*elemsPerBlock + tidz/BASIS_NCOMP;

  for (CeedInt elem = firstElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    for (int i = 0; i < T1D; ++i)
      r_a[i] = r_b[i] = 0.0;

    if (transpose) {
      readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_a);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_a, c_B, r_b);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_b, c_B, r_a);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_a, c_B, r_b);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_b, d_V);
    } else {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_a);
      ContractX3d(slice, tidx, tidy, tidz, r_a, c_B, r_b);
      ContractY3d(slice, tidx, tidy, tidz, r_b, c_B, r_a);
      ContractZ3d(slice, tidx, tidy, tidz, r_a, c_B, r_b);
      writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_b, d_V);
    }
  }
}
618c532df63SYohann 
619ab213215SJeremy L Thompson //------------------------------------------------------------------------------
620ab213215SJeremy L Thompson // 3D derivatives at quadrature points
621ab213215SJeremy L Thompson //------------------------------------------------------------------------------
//
// 3D tensor-product gradient. For derivative direction d, the 1D gradient
// matrix c_G is applied along axis d and the interpolation matrix c_B along
// the other two axes.
//   transpose == 0 : dofs (d_U) -> one quadrature field per direction d,
//                    each written with writeQuads3d(..., d, ...).
//   transpose != 0 : the three quadrature fields are pulled back with the
//                    transposed contractions, summed, and written as dofs.
// threadIdx.z packs (element within the block, component). Blocks stride over
// elements.
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  CeedScalar r_in[T1D];   // values read from global memory
  CeedScalar r_out[T1D];  // result (accumulated over d in transpose mode)
  CeedScalar r_tmp[T1D];  // scratch between contractions

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int comp = tidz%BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const CeedInt firstElem = blockIdx.x*elemsPerBlock + tidz/BASIS_NCOMP;

  for (CeedInt elem = firstElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    for (int i = 0; i < T1D; ++i)
      r_in[i] = r_out[i] = r_tmp[i] = 0.0;

    if (transpose) {
      // d = 0: G along x
      readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_in);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_in, c_B, r_tmp);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_tmp, c_B, r_in);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_in, c_G, r_out);
      // d = 1: G along y
      readQuads3d(elem, tidx, tidy, comp, 1, nelem, d_U, r_in);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_in, c_B, r_tmp);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_tmp, c_G, r_in);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_in, c_B, r_tmp);
      add(r_out, r_tmp);
      // d = 2: G along z
      readQuads3d(elem, tidx, tidy, comp, 2, nelem, d_U, r_in);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_in, c_G, r_tmp);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_tmp, c_B, r_in);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_in, c_B, r_tmp);
      add(r_out, r_tmp);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_out, d_V);
    } else {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_in);
      // d = 0: G along x
      ContractX3d(slice, tidx, tidy, tidz, r_in, c_G, r_out);
      ContractY3d(slice, tidx, tidy, tidz, r_out, c_B, r_tmp);
      ContractZ3d(slice, tidx, tidy, tidz, r_tmp, c_B, r_out);
      writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_out, d_V);
      // d = 1: G along y
      ContractX3d(slice, tidx, tidy, tidz, r_in, c_B, r_out);
      ContractY3d(slice, tidx, tidy, tidz, r_out, c_G, r_tmp);
      ContractZ3d(slice, tidx, tidy, tidz, r_tmp, c_B, r_out);
      writeQuads3d(elem, tidx, tidy, comp, 1, nelem, r_out, d_V);
      // d = 2: G along z
      ContractX3d(slice, tidx, tidy, tidz, r_in, c_B, r_out);
      ContractY3d(slice, tidx, tidy, tidz, r_out, c_B, r_tmp);
      ContractZ3d(slice, tidx, tidy, tidz, r_tmp, c_G, r_out);
      writeQuads3d(elem, tidx, tidy, comp, 2, nelem, r_out, d_V);
    }
  }
}
686c532df63SYohann 
687ab213215SJeremy L Thompson //------------------------------------------------------------------------------
688ab213215SJeremy L Thompson // 3D quadrature weights
689ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Fill w with the tensor-product quadrature weights.
// Thread (x, y, z) of a block computes qweight1d[x]*qweight1d[y]*qweight1d[z].
// Blocks stride over elements, storing that weight at offset
// x + y*Q1D + z*Q1D*Q1D of each element's Q1D^3 slab.
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const int qz = threadIdx.z;
  const int q = qx + qy*Q1D + qz*Q1D*Q1D;
  const CeedScalar weight = qweight1d[qx]*qweight1d[qy]*qweight1d[qz];
  for (int e = blockIdx.x; e < nelem; e += gridDim.x)
    w[e*Q1D*Q1D*Q1D + q] = weight;
}
701ab213215SJeremy L Thompson 
702ab213215SJeremy L Thompson 
703ab213215SJeremy L Thompson //------------------------------------------------------------------------------
704ab213215SJeremy L Thompson // Basis kernels
705ab213215SJeremy L Thompson //------------------------------------------------------------------------------
706ab213215SJeremy L Thompson 
707ab213215SJeremy L Thompson //------------------------------------------------------------------------------
708ab213215SJeremy L Thompson // Interp kernel by dim
709ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Interpolation kernel entry point; dispatches on BASIS_DIM.
// Requires dynamic shared memory for the scratch slice; its size is chosen by
// the host launch.
extern "C" __global__ void interp(const CeedInt nelem, const int transpose,
                                  const CeedScalar *c_B,
                                  const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  // Declared as CeedScalar (not a hard-coded double) so the type matches the
  // CeedScalar *slice parameter of the helpers. It also matches the host,
  // which sizes this buffer in units of sizeof(CeedScalar).
  extern __shared__ CeedScalar slice[];
  if (BASIS_DIM == 1) {
    interp1d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM == 2) {
    interp2d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM == 3) {
    interp3d(nelem, transpose, c_B, d_U, d_V, slice);
  }
}
723c532df63SYohann 
724ab213215SJeremy L Thompson //------------------------------------------------------------------------------
725ab213215SJeremy L Thompson // Grad kernel by dim
726ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Gradient kernel entry point; dispatches on BASIS_DIM.
// Requires dynamic shared memory for the scratch slice; its size is chosen by
// the host launch.
extern "C" __global__ void grad(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *c_G,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V) {
  // Declared as CeedScalar (not a hard-coded double) so the type matches the
  // CeedScalar *slice parameter of the helpers. It also matches the host,
  // which sizes this buffer in units of sizeof(CeedScalar).
  extern __shared__ CeedScalar slice[];
  if (BASIS_DIM == 1) {
    grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM == 2) {
    grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM == 3) {
    grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  }
}
740c532df63SYohann 
741ab213215SJeremy L Thompson //------------------------------------------------------------------------------
742ab213215SJeremy L Thompson // Weight kernels by dim
743ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Quadrature-weight kernel entry point: selects the weight helper that matches
// BASIS_DIM. No shared memory is used.
extern "C" __global__ void weight(const CeedInt nelem,
                                  const CeedScalar *__restrict__ qweight1d,
                                  CeedScalar *__restrict__ v) {
  if (BASIS_DIM == 3) {
    weight3d(nelem, qweight1d, v);
  } else if (BASIS_DIM == 2) {
    weight2d(nelem, qweight1d, v);
  } else if (BASIS_DIM == 1) {
    weight1d(nelem, qweight1d, v);
  }
}
755c532df63SYohann 
756c532df63SYohann );
757cb0b5415Sjeremylt // *INDENT-ON*
758c532df63SYohann 
759ab213215SJeremy L Thompson //------------------------------------------------------------------------------
760ab213215SJeremy L Thompson // Device initalization
761ab213215SJeremy L Thompson //------------------------------------------------------------------------------
762c532df63SYohann int CeedCudaInitInterp(CeedScalar *d_B, CeedInt P1d, CeedInt Q1d,
763c532df63SYohann                        CeedScalar **c_B);
764c532df63SYohann int CeedCudaInitInterpGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P1d,
7657f823360Sjeremylt                            CeedInt Q1d, CeedScalar **c_B_ptr,
7667f823360Sjeremylt                            CeedScalar **c_G_ptr);
767c532df63SYohann 
768ab213215SJeremy L Thompson //------------------------------------------------------------------------------
769ab213215SJeremy L Thompson // Apply basis
770ab213215SJeremy L Thompson //------------------------------------------------------------------------------
771c532df63SYohann int CeedBasisApplyTensor_Cuda_shared(CeedBasis basis, const CeedInt nelem,
772c532df63SYohann                                      CeedTransposeMode tmode,
7737f823360Sjeremylt                                      CeedEvalMode emode, CeedVector u,
7747f823360Sjeremylt                                      CeedVector v) {
775c532df63SYohann   int ierr;
776c532df63SYohann   Ceed ceed;
777*e15f9bd0SJeremy L Thompson   ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
778c532df63SYohann   Ceed_Cuda_shared *ceed_Cuda;
779*e15f9bd0SJeremy L Thompson   CeedGetData(ceed, &ceed_Cuda); CeedChkBackend(ierr);
780c532df63SYohann   CeedBasis_Cuda_shared *data;
781*e15f9bd0SJeremy L Thompson   CeedBasisGetData(basis, &data); CeedChkBackend(ierr);
782c532df63SYohann   const CeedInt transpose = tmode == CEED_TRANSPOSE;
7834247ecf3SYohann Dudouit   CeedInt dim, ncomp;
784*e15f9bd0SJeremy L Thompson   ierr = CeedBasisGetDimension(basis, &dim); CeedChkBackend(ierr);
785*e15f9bd0SJeremy L Thompson   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChkBackend(ierr);
786c532df63SYohann 
787ab213215SJeremy L Thompson   // Read vectors
788c532df63SYohann   const CeedScalar *d_u;
789c532df63SYohann   CeedScalar *d_v;
790c532df63SYohann   if (emode != CEED_EVAL_WEIGHT) {
791*e15f9bd0SJeremy L Thompson     ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChkBackend(ierr);
792c532df63SYohann   }
793*e15f9bd0SJeremy L Thompson   ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChkBackend(ierr);
794c532df63SYohann 
795ab213215SJeremy L Thompson   // Clear v for transpose mode
796c532df63SYohann   if (tmode == CEED_TRANSPOSE) {
797c532df63SYohann     CeedInt length;
798*e15f9bd0SJeremy L Thompson     ierr = CeedVectorGetLength(v, &length); CeedChkBackend(ierr);
799*e15f9bd0SJeremy L Thompson     ierr = cudaMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChkBackend(ierr);
800c532df63SYohann   }
801ab213215SJeremy L Thompson 
802ab213215SJeremy L Thompson   // Apply basis operation
803ab213215SJeremy L Thompson   switch (emode) {
804ab213215SJeremy L Thompson   case CEED_EVAL_INTERP: {
805c532df63SYohann     CeedInt P1d, Q1d;
806*e15f9bd0SJeremy L Thompson     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChkBackend(ierr);
807*e15f9bd0SJeremy L Thompson     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr);
80818d499f1SYohann     CeedInt thread1d = CeedIntMax(Q1d, P1d);
809c532df63SYohann     ierr = CeedCudaInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B);
810*e15f9bd0SJeremy L Thompson     CeedChkBackend(ierr);
811cb0b5415Sjeremylt     void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
812ccf0fe6fSjeremylt                           &d_u, &d_v
813ccf0fe6fSjeremylt                          };
8144d537eeaSYohann     if (dim == 1) {
815d94769d2SYohann Dudouit       CeedInt elemsPerBlock = 32;
8164d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
8174d537eeaSYohann                                              ? 1 : 0 );
81818d499f1SYohann       CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
81918d499f1SYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, thread1d, 1,
8204d537eeaSYohann                                         elemsPerBlock, sharedMem,
821*e15f9bd0SJeremy L Thompson                                         interpargs); CeedChkBackend(ierr);
822074be161SYohann Dudouit     } else if (dim == 2) {
8234247ecf3SYohann Dudouit       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
8240f70cdf6SJeremy L Thompson       // elemsPerBlock must be at least 1
82518d499f1SYohann       CeedInt elemsPerBlock = CeedIntMax(thread1d<7?optElems[thread1d]/ncomp:1, 1);
8264d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
8274d537eeaSYohann                                              ? 1 : 0 );
82818d499f1SYohann       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
82918d499f1SYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, thread1d, thread1d,
8304d537eeaSYohann                                         ncomp*elemsPerBlock, sharedMem,
831*e15f9bd0SJeremy L Thompson                                         interpargs); CeedChkBackend(ierr);
832074be161SYohann Dudouit     } else if (dim == 3) {
8333f63d318SYohann Dudouit       CeedInt elemsPerBlock = 1;
8344d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
8354d537eeaSYohann                                              ? 1 : 0 );
83618d499f1SYohann       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
83718d499f1SYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, thread1d, thread1d,
8384d537eeaSYohann                                         ncomp*elemsPerBlock, sharedMem,
839*e15f9bd0SJeremy L Thompson                                         interpargs); CeedChkBackend(ierr);
840074be161SYohann Dudouit     }
841ab213215SJeremy L Thompson   } break;
842ab213215SJeremy L Thompson   case CEED_EVAL_GRAD: {
843c532df63SYohann     CeedInt P1d, Q1d;
844*e15f9bd0SJeremy L Thompson     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChkBackend(ierr);
845*e15f9bd0SJeremy L Thompson     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr);
84618d499f1SYohann     CeedInt thread1d = CeedIntMax(Q1d, P1d);
847c532df63SYohann     ierr = CeedCudaInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d,
848c532df63SYohann                                   Q1d, &data->c_B, &data->c_G);
849*e15f9bd0SJeremy L Thompson     CeedChkBackend(ierr);
850cb0b5415Sjeremylt     void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
851ccf0fe6fSjeremylt                         &data->c_G, &d_u, &d_v
852ccf0fe6fSjeremylt                        };
8534d537eeaSYohann     if (dim == 1) {
854d94769d2SYohann Dudouit       CeedInt elemsPerBlock = 32;
8554d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
8564d537eeaSYohann                                              ? 1 : 0 );
85718d499f1SYohann       CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
85818d499f1SYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, thread1d, 1,
859ab213215SJeremy L Thompson                                         elemsPerBlock, sharedMem, gradargs);
860*e15f9bd0SJeremy L Thompson       CeedChkBackend(ierr);
861074be161SYohann Dudouit     } else if (dim == 2) {
8624247ecf3SYohann Dudouit       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
8630f70cdf6SJeremy L Thompson       // elemsPerBlock must be at least 1
86418d499f1SYohann       CeedInt elemsPerBlock = CeedIntMax(thread1d<7?optElems[thread1d]/ncomp:1, 1);
8654d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
8664d537eeaSYohann                                              ? 1 : 0 );
86718d499f1SYohann       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
86818d499f1SYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, thread1d, thread1d,
8694d537eeaSYohann                                         ncomp*elemsPerBlock, sharedMem,
870*e15f9bd0SJeremy L Thompson                                         gradargs); CeedChkBackend(ierr);
871074be161SYohann Dudouit     } else if (dim == 3) {
8723f63d318SYohann Dudouit       CeedInt elemsPerBlock = 1;
8734d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
8744d537eeaSYohann                                              ? 1 : 0 );
87518d499f1SYohann       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
87618d499f1SYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, thread1d, thread1d,
8774d537eeaSYohann                                         ncomp*elemsPerBlock, sharedMem,
878*e15f9bd0SJeremy L Thompson                                         gradargs); CeedChkBackend(ierr);
879074be161SYohann Dudouit     }
880ab213215SJeremy L Thompson   } break;
881ab213215SJeremy L Thompson   case CEED_EVAL_WEIGHT: {
882074be161SYohann Dudouit     CeedInt Q1d;
883*e15f9bd0SJeremy L Thompson     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr);
884c532df63SYohann     void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
885074be161SYohann Dudouit     if (dim == 1) {
886074be161SYohann Dudouit       const CeedInt elemsPerBlock = 32/Q1d;
8874d537eeaSYohann       const CeedInt gridsize = nelem/elemsPerBlock + ( (
8884d537eeaSYohann                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
8897f823360Sjeremylt       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d,
8907f823360Sjeremylt                                   elemsPerBlock, 1, weightargs);
891*e15f9bd0SJeremy L Thompson       CeedChkBackend(ierr);
892074be161SYohann Dudouit     } else if (dim == 2) {
893717ff8a3SYohann Dudouit       const CeedInt optElems = 32/(Q1d*Q1d);
894717ff8a3SYohann Dudouit       const CeedInt elemsPerBlock = optElems>0?optElems:1;
8954d537eeaSYohann       const CeedInt gridsize = nelem/elemsPerBlock + ( (
8964d537eeaSYohann                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
8974d537eeaSYohann       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d,
8984d537eeaSYohann                                   elemsPerBlock, weightargs);
899*e15f9bd0SJeremy L Thompson       CeedChkBackend(ierr);
900074be161SYohann Dudouit     } else if (dim == 3) {
901074be161SYohann Dudouit       const CeedInt gridsize = nelem;
9024d537eeaSYohann       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d, Q1d,
9034d537eeaSYohann                                   weightargs);
904*e15f9bd0SJeremy L Thompson       CeedChkBackend(ierr);
905074be161SYohann Dudouit     }
906ab213215SJeremy L Thompson   } break;
907ab213215SJeremy L Thompson   // LCOV_EXCL_START
908ab213215SJeremy L Thompson   // Evaluate the divergence to/from the quadrature points
909ab213215SJeremy L Thompson   case CEED_EVAL_DIV:
910*e15f9bd0SJeremy L Thompson     return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported");
911ab213215SJeremy L Thompson   // Evaluate the curl to/from the quadrature points
912ab213215SJeremy L Thompson   case CEED_EVAL_CURL:
913*e15f9bd0SJeremy L Thompson     return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported");
914ab213215SJeremy L Thompson   // Take no action, BasisApply should not have been called
915ab213215SJeremy L Thompson   case CEED_EVAL_NONE:
916*e15f9bd0SJeremy L Thompson     return CeedError(ceed, CEED_ERROR_BACKEND,
917ab213215SJeremy L Thompson                      "CEED_EVAL_NONE does not make sense in this context");
918ab213215SJeremy L Thompson     // LCOV_EXCL_STOP
919c532df63SYohann   }
920c532df63SYohann 
921ab213215SJeremy L Thompson   // Restore vectors
922c532df63SYohann   if (emode != CEED_EVAL_WEIGHT) {
923*e15f9bd0SJeremy L Thompson     ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChkBackend(ierr);
924c532df63SYohann   }
925*e15f9bd0SJeremy L Thompson   ierr = CeedVectorRestoreArray(v, &d_v); CeedChkBackend(ierr);
926*e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
927c532df63SYohann }
928c532df63SYohann 
929ab213215SJeremy L Thompson //------------------------------------------------------------------------------
930ab213215SJeremy L Thompson // Destroy basis
931ab213215SJeremy L Thompson //------------------------------------------------------------------------------
932c532df63SYohann static int CeedBasisDestroy_Cuda_shared(CeedBasis basis) {
933c532df63SYohann   int ierr;
934c532df63SYohann   Ceed ceed;
935*e15f9bd0SJeremy L Thompson   ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
936c532df63SYohann 
937c532df63SYohann   CeedBasis_Cuda_shared *data;
938*e15f9bd0SJeremy L Thompson   ierr = CeedBasisGetData(basis, &data); CeedChkBackend(ierr);
939c532df63SYohann 
940c532df63SYohann   CeedChk_Cu(ceed, cuModuleUnload(data->module));
941c532df63SYohann 
942c532df63SYohann   ierr = cudaFree(data->d_qweight1d); CeedChk_Cu(ceed, ierr);
943c532df63SYohann   ierr = cudaFree(data->d_interp1d); CeedChk_Cu(ceed, ierr);
944c532df63SYohann   ierr = cudaFree(data->d_grad1d); CeedChk_Cu(ceed, ierr);
9451958eb7cSJeremy L Thompson   ierr = cudaFree(data->d_collograd1d); CeedChk_Cu(ceed, ierr);
946c532df63SYohann 
947*e15f9bd0SJeremy L Thompson   ierr = CeedFree(&data); CeedChkBackend(ierr);
948c532df63SYohann 
949*e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
950c532df63SYohann }
951c532df63SYohann 
952ab213215SJeremy L Thompson //------------------------------------------------------------------------------
953ab213215SJeremy L Thompson // Create tensor basis
954ab213215SJeremy L Thompson //------------------------------------------------------------------------------
955c532df63SYohann int CeedBasisCreateTensorH1_Cuda_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
956c532df63SYohann                                         const CeedScalar *interp1d,
957c532df63SYohann                                         const CeedScalar *grad1d,
958c532df63SYohann                                         const CeedScalar *qref1d,
959c532df63SYohann                                         const CeedScalar *qweight1d,
960c532df63SYohann                                         CeedBasis basis) {
961c532df63SYohann   int ierr;
962c532df63SYohann   Ceed ceed;
963*e15f9bd0SJeremy L Thompson   ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
964c532df63SYohann   CeedBasis_Cuda_shared *data;
965*e15f9bd0SJeremy L Thompson   ierr = CeedCalloc(1, &data); CeedChkBackend(ierr);
966c532df63SYohann 
967ab213215SJeremy L Thompson   // Copy basis data to GPU
968c532df63SYohann   const CeedInt qBytes = Q1d * sizeof(CeedScalar);
969c532df63SYohann   ierr = cudaMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Cu(ceed, ierr);
970c532df63SYohann   ierr = cudaMemcpy(data->d_qweight1d, qweight1d, qBytes,
971c532df63SYohann                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
972c532df63SYohann 
973c532df63SYohann   const CeedInt iBytes = qBytes * P1d;
974c532df63SYohann   ierr = cudaMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Cu(ceed, ierr);
975c532df63SYohann   ierr = cudaMemcpy(data->d_interp1d, interp1d, iBytes,
976c532df63SYohann                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
977c532df63SYohann 
978c532df63SYohann   ierr = cudaMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Cu(ceed, ierr);
979c532df63SYohann   ierr = cudaMemcpy(data->d_grad1d, grad1d, iBytes,
980c532df63SYohann                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
981c532df63SYohann 
982ab213215SJeremy L Thompson   // Compute collocated gradient and copy to GPU
983ac421f39SYohann   data->d_collograd1d = NULL;
984ac421f39SYohann   if (dim == 3 && Q1d >= P1d) {
985ac421f39SYohann     CeedScalar *collograd1d;
986*e15f9bd0SJeremy L Thompson     ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChkBackend(ierr);
987*e15f9bd0SJeremy L Thompson     ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChkBackend(ierr);
988ac421f39SYohann     ierr = cudaMalloc((void **)&data->d_collograd1d, qBytes * Q1d);
989ac421f39SYohann     CeedChk_Cu(ceed, ierr);
990ac421f39SYohann     ierr = cudaMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d,
991ac421f39SYohann                       cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
992*e15f9bd0SJeremy L Thompson     ierr = CeedFree(&collograd1d); CeedChkBackend(ierr);
993ac421f39SYohann   }
994ac421f39SYohann 
995ab213215SJeremy L Thompson   // Compile basis kernels
996c532df63SYohann   CeedInt ncomp;
997*e15f9bd0SJeremy L Thompson   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChkBackend(ierr);
99818d499f1SYohann   ierr = CeedCompileCuda(ceed, kernelsShared, &data->module, 8,
999c532df63SYohann                          "Q1D", Q1d,
1000c532df63SYohann                          "P1D", P1d,
100118d499f1SYohann                          "T1D", CeedIntMax(Q1d, P1d),
1002c532df63SYohann                          "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
1003c532df63SYohann                              Q1d : P1d, dim),
1004c532df63SYohann                          "BASIS_DIM", dim,
1005c532df63SYohann                          "BASIS_NCOMP", ncomp,
1006c532df63SYohann                          "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
1007c532df63SYohann                          "BASIS_NQPT", CeedIntPow(Q1d, dim)
1008*e15f9bd0SJeremy L Thompson                         ); CeedChkBackend(ierr);
10094a6d4bbdSYohann Dudouit   ierr = CeedGetKernelCuda(ceed, data->module, "interp", &data->interp);
1010*e15f9bd0SJeremy L Thompson   CeedChkBackend(ierr);
10114a6d4bbdSYohann Dudouit   ierr = CeedGetKernelCuda(ceed, data->module, "grad", &data->grad);
1012*e15f9bd0SJeremy L Thompson   CeedChkBackend(ierr);
10134a6d4bbdSYohann Dudouit   ierr = CeedGetKernelCuda(ceed, data->module, "weight", &data->weight);
1014*e15f9bd0SJeremy L Thompson   CeedChkBackend(ierr);
1015c532df63SYohann 
1016*e15f9bd0SJeremy L Thompson   ierr = CeedBasisSetData(basis, data); CeedChkBackend(ierr);
1017ab213215SJeremy L Thompson 
1018ab213215SJeremy L Thompson   // Register backend functions
1019c532df63SYohann   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
1020c532df63SYohann                                 CeedBasisApplyTensor_Cuda_shared);
1021*e15f9bd0SJeremy L Thompson   CeedChkBackend(ierr);
1022c532df63SYohann   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
1023*e15f9bd0SJeremy L Thompson                                 CeedBasisDestroy_Cuda_shared); CeedChkBackend(ierr);
1024*e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1025c532df63SYohann }
1026ab213215SJeremy L Thompson //------------------------------------------------------------------------------
1027