xref: /libCEED/rust/libceed-sys/c-src/backends/cuda-shared/ceed-cuda-shared-basis.c (revision ec3da8bcb94d9f0073544b37b5081a06981a86f7)
1c532df63SYohann // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2c532df63SYohann // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3c532df63SYohann // All Rights reserved. See files LICENSE and NOTICE for details.
4c532df63SYohann //
5c532df63SYohann // This file is part of CEED, a collection of benchmarks, miniapps, software
6c532df63SYohann // libraries and APIs for efficient high-order finite element and spectral
7c532df63SYohann // element discretizations for exascale applications. For more information and
8c532df63SYohann // source code availability see http://github.com/ceed.
9c532df63SYohann //
10c532df63SYohann // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11c532df63SYohann // a collaborative effort of two U.S. Department of Energy organizations (Office
12c532df63SYohann // of Science and the National Nuclear Security Administration) responsible for
13c532df63SYohann // the planning and preparation of a capable exascale ecosystem, including
14c532df63SYohann // software, applications, hardware, advanced system engineering and early
15c532df63SYohann // testbed platforms, in support of the nation's exascale computing imperative.
16c532df63SYohann 
17*ec3da8bcSJed Brown #include <ceed/ceed.h>
18*ec3da8bcSJed Brown #include <ceed/backend.h>
193d576824SJeremy L Thompson #include <cuda.h>
203d576824SJeremy L Thompson #include <cuda_runtime.h>
213d576824SJeremy L Thompson #include <stddef.h>
22c532df63SYohann #include "ceed-cuda-shared.h"
233d576824SJeremy L Thompson #include "../cuda/ceed-cuda.h"
24c532df63SYohann 
25ab213215SJeremy L Thompson //------------------------------------------------------------------------------
26ab213215SJeremy L Thompson // Shared mem kernels
27ab213215SJeremy L Thompson //------------------------------------------------------------------------------
28cb0b5415Sjeremylt // *INDENT-OFF*
29c532df63SYohann static const char *kernelsShared = QUOTE(
30c532df63SYohann 
31ab213215SJeremy L Thompson //------------------------------------------------------------------------------
32ab213215SJeremy L Thompson // Sum input into output
33ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Element-wise accumulate: r_V[k] += r_U[k] for the first P1D entries.
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  for (int k = 0; k < P1D; ++k) {
    r_V[k] += r_U[k];
  }
}
38c532df63SYohann 
39ab213215SJeremy L Thompson //------------------------------------------------------------------------------
40ab213215SJeremy L Thompson // 1D
41ab213215SJeremy L Thompson //------------------------------------------------------------------------------
42c532df63SYohann 
43ab213215SJeremy L Thompson //------------------------------------------------------------------------------
44ab213215SJeremy L Thompson // Read DoFs
45ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Stage the P1D nodal values of component comp of element elem into the
// z-row tidz of slice (row stride T1D). d_U is indexed as [comp][elem][P1D].
// Every thread of the row writes the same values to the whole row. Entries
// P1D..Q1D-1 are zeroed so the row is padded when Q1D > P1D.
// tidx and tidy are unused here.
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int tidz, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *slice) {
  CeedScalar *row = slice + tidz*T1D;
  const int offset = elem*P1D + comp*P1D*nelem;
  const CeedScalar *src = d_U + offset;
  for (int p = 0; p < P1D; ++p)
    row[p] = src[p];
  for (int p = P1D; p < Q1D; ++p)
    row[p] = 0.0;
}
55c532df63SYohann 
56ab213215SJeremy L Thompson //------------------------------------------------------------------------------
57ab213215SJeremy L Thompson // Write DoFs
58ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Store this thread's nodal value r_V at node tidx of component comp of
// element elem. d_V is indexed as [comp][elem][P1D]. Threads with
// tidx >= P1D have no node and store nothing. tidy is unused.
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D)
    return;
  const int base = comp*P1D*nelem + elem*P1D;
  d_V[base + tidx] = r_V;
}
66c532df63SYohann 
67ab213215SJeremy L Thompson //------------------------------------------------------------------------------
68ab213215SJeremy L Thompson // Read quadrature point data
69ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Stage the Q1D quadrature values of (dim, comp, elem) into the z-row tidz
// of slice (row stride T1D). d_U is indexed as [dim][comp][elem][Q1D].
// Entries Q1D..P1D-1 are zeroed so the row is padded when P1D > Q1D.
// tidx and tidy are unused here.
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int tidz, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *slice) {
  CeedScalar *row = slice + tidz*T1D;
  const int offset = elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  const CeedScalar *src = d_U + offset;
  for (int q = 0; q < Q1D; ++q)
    row[q] = src[q];
  for (int q = Q1D; q < P1D; ++q)
    row[q] = 0.0;
}
80c532df63SYohann 
81ab213215SJeremy L Thompson //------------------------------------------------------------------------------
82ab213215SJeremy L Thompson // Write quadrature point data
83ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Store this thread's quadrature value r_V at point tidx of (dim, comp, elem).
// d_V is indexed as [dim][comp][elem][Q1D]. Threads with tidx >= Q1D store
// nothing. tidy is unused.
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx >= Q1D)
    return;
  const int base = dim*BASIS_NCOMP*nelem*Q1D + comp*Q1D*nelem + elem*Q1D;
  d_V[base + tidx] = r_V;
}
91c532df63SYohann 
92ab213215SJeremy L Thompson //------------------------------------------------------------------------------
93ab213215SJeremy L Thompson // 1D tensor contraction
94ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// 1D contraction: V = sum_p B[p + q*P1D] * row[p] for quadrature point q = tidx.
// Here row is the z-row tidz of slice (stride T1D), staged by readDofs1d.
// B holds Q1D rows of P1D entries. U and tidy are unused; they keep the
// signature parallel to the 2D/3D contractions.
//
// Only threads with tidx < Q1D own a row of B. Without the guard, threads
// of a row padded past Q1D (possible when the block is wider than Q1D,
// presumably blockDim.x == T1D) read past the end of B. Guarded-out
// threads return V = 0. Callers in this file store V only for tidx < Q1D
// (writeQuads1d guard), so results are unchanged. This is the same guard
// ContractX2d already uses.
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  V = 0.0;
  if (tidx < Q1D)
    for (int i = 0; i < P1D; ++i)
      V += B[i + tidx*P1D] * slice[i + tidz*T1D]; // Contract x direction
}
103c532df63SYohann 
104ab213215SJeremy L Thompson //------------------------------------------------------------------------------
105ab213215SJeremy L Thompson // 1D transpose tensor contraction
106ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// 1D transpose contraction: V = sum_q B[p + q*P1D] * row[q] for node p = tidx.
// Here row is the z-row tidz of slice (stride T1D), staged by readQuads1d.
// B holds Q1D rows of P1D entries. U and tidy are unused.
//
// Only threads with tidx < P1D correspond to a column of B. Without the
// guard, threads with tidx >= P1D (possible when Q1D > P1D) index
// B[tidx + i*P1D] past the end of B on the last rows. Guarded-out threads
// return V = 0. Callers in this file store V only for tidx < P1D
// (writeDofs1d guard). This matches ContractTransposeX2d.
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  V = 0.0;
  if (tidx < P1D)
    for (int i = 0; i < Q1D; ++i)
      V += B[tidx + i*P1D] * slice[i + tidz*T1D]; // Contract x direction
}
114c532df63SYohann 
115ab213215SJeremy L Thompson //------------------------------------------------------------------------------
116ab213215SJeremy L Thompson // 1D interpolate to quadrature points
117ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// 1D interpolation, one element per z-row of threads.
// Each thread z-row walks elements starting at blockIdx.x*blockDim.z +
// threadIdx.z with stride gridDim.x*blockDim.z, and handles every component.
//   transpose == 0: nodal values (P1D) -> quadrature values (Q1D) via c_B.
//   transpose != 0: quadrature values -> nodal values via c_B transposed.
// NOTE(review): slice is rewritten for every (element, component) with no
// barrier between one iteration's contraction reads and the next
// iteration's staging writes. If the threads of one z-row span two warps,
// they may race. Confirm this against the launch configuration.
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  const int tx = threadIdx.x;
  const int ty = threadIdx.y;
  const int tz = threadIdx.z;
  CeedScalar result;
  CeedScalar unusedU; // placeholder for the unused U argument of the 1D contractions

  for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < nelem;
       e += gridDim.x*blockDim.z) {
    for (int c = 0; c < BASIS_NCOMP; ++c) {
      if (transpose) {
        readQuads1d(e, tx, ty, tz, c, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tx, ty, tz, unusedU, c_B, result);
        writeDofs1d(e, tx, ty, c, nelem, result, d_V);
      } else {
        readDofs1d(e, tx, ty, tz, c, nelem, d_U, slice);
        ContractX1d(slice, tx, ty, tz, unusedU, c_B, result);
        writeQuads1d(e, tx, ty, c, 0, nelem, result, d_V);
      }
    }
  }
}
145c532df63SYohann 
146ab213215SJeremy L Thompson //------------------------------------------------------------------------------
147ab213215SJeremy L Thompson // 1D derivatives at quadrature points
148ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// 1D gradient, one element per z-row of threads. The element loop matches
// interp1d. The derivative matrix c_G is applied, and the single derivative
// direction is stored or read as dim 0. c_B is accepted for a signature
// parallel to the 2D/3D versions but is unused in 1D.
// NOTE(review): as in interp1d, slice is reused across iterations without
// a barrier. Confirm that the threads of one z-row cannot race.
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  const int tx = threadIdx.x;
  const int ty = threadIdx.y;
  const int tz = threadIdx.z;
  const int xDim = 0; // the only derivative direction in 1D
  CeedScalar result;
  CeedScalar unusedU; // placeholder for the unused U argument of the 1D contractions

  for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < nelem;
       e += gridDim.x*blockDim.z) {
    for (int c = 0; c < BASIS_NCOMP; ++c) {
      if (transpose) {
        readQuads1d(e, tx, ty, tz, c, xDim, nelem, d_U, slice);
        ContractTransposeX1d(slice, tx, ty, tz, unusedU, c_G, result);
        writeDofs1d(e, tx, ty, c, nelem, result, d_V);
      } else {
        readDofs1d(e, tx, ty, tz, c, nelem, d_U, slice);
        ContractX1d(slice, tx, ty, tz, unusedU, c_G, result);
        writeQuads1d(e, tx, ty, c, xDim, nelem, result, d_V);
      }
    }
  }
}
179c532df63SYohann 
180ab213215SJeremy L Thompson //------------------------------------------------------------------------------
181ab213215SJeremy L Thompson // 1D Quadrature weights
182ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Fill w with the 1D quadrature weights for every element: w[e*Q1D + q] =
// qweight1d[q], where q = threadIdx.x. Elements are distributed over
// threadIdx.y / blockDim.y and the grid.
// qweight1d is indexed by threadIdx.x without a bound check, so this
// presumably expects blockDim.x == Q1D. Confirm on the host side.
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int q = threadIdx.x;
  const CeedScalar wq = qweight1d[q];
  CeedInt e = blockIdx.x*blockDim.y + threadIdx.y;
  while (e < nelem) {
    w[e*Q1D + q] = wq;
    e += gridDim.x*blockDim.y;
  }
}
193ab213215SJeremy L Thompson 
194ab213215SJeremy L Thompson //------------------------------------------------------------------------------
195ab213215SJeremy L Thompson // 2D
196ab213215SJeremy L Thompson //------------------------------------------------------------------------------
197ab213215SJeremy L Thompson 
198ab213215SJeremy L Thompson //------------------------------------------------------------------------------
199ab213215SJeremy L Thompson // Read DoFs
200ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Load the nodal value at node (tidx, tidy) of component comp of element
// elem into U. d_U is indexed as [comp][elem][P1D][P1D] (y major, x minor).
// Threads outside the P1D x P1D node grid get 0.
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar &U) {
  if (tidx < P1D && tidy < P1D) {
    const int base = comp*P1D*P1D*nelem + elem*P1D*P1D;
    U = d_U[base + tidy*P1D + tidx];
  } else {
    U = 0.0;
  }
}
208c532df63SYohann 
209ab213215SJeremy L Thompson //------------------------------------------------------------------------------
210ab213215SJeremy L Thompson // Write DoFs
211ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Store r_V at node (tidx, tidy) of component comp of element elem.
// d_V is indexed as [comp][elem][P1D][P1D]. Threads outside the
// P1D x P1D node grid store nothing.
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D)
    return;
  const int base = comp*P1D*P1D*nelem + elem*P1D*P1D;
  d_V[base + tidy*P1D + tidx] = r_V;
}
219c532df63SYohann 
220ab213215SJeremy L Thompson //------------------------------------------------------------------------------
221ab213215SJeremy L Thompson // Read quadrature point data
222ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Load the quadrature value at point (tidx, tidy) of (dim, comp, elem) into U.
// d_U is indexed as [dim][comp][elem][Q1D][Q1D]. Threads outside the
// Q1D x Q1D point grid get 0.
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar &U) {
  if (tidx < Q1D && tidy < Q1D) {
    const int base = dim*BASIS_NCOMP*nelem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                     elem*Q1D*Q1D;
    U = d_U[base + tidy*Q1D + tidx];
  } else {
    U = 0.0;
  }
}
231c532df63SYohann 
232ab213215SJeremy L Thompson //------------------------------------------------------------------------------
233ab213215SJeremy L Thompson // Write quadrature point data
234ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Store r_V at quadrature point (tidx, tidy) of (dim, comp, elem).
// d_V is indexed as [dim][comp][elem][Q1D][Q1D]. Threads outside the
// Q1D x Q1D point grid store nothing.
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx >= Q1D || tidy >= Q1D)
    return;
  const int base = dim*BASIS_NCOMP*nelem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                   elem*Q1D*Q1D;
  d_V[base + tidy*Q1D + tidx] = r_V;
}
243c532df63SYohann 
244ab213215SJeremy L Thompson //------------------------------------------------------------------------------
245ab213215SJeremy L Thompson // 2D tensor contraction x
246ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// 2D contraction along x. Each thread publishes U into its z-slab of slice
// (a T1D x T1D plane at offset tidz*T1D*T1D). After a barrier, threads with
// tidx < Q1D compute V = sum_p B[p + tidx*P1D] * plane[p, tidy]. B holds
// Q1D rows of P1D entries; other threads get V = 0. Both barriers are
// outside the branch and must be reached by every thread of the block.
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  CeedScalar *plane = slice + tidz*T1D*T1D;
  plane[tidx + tidy*T1D] = U;
  __syncthreads();
  CeedScalar acc = 0.0;
  if (tidx < Q1D) {
    const CeedScalar *Brow = B + tidx*P1D;
    const CeedScalar *srow = plane + tidy*T1D;
    for (int p = 0; p < P1D; ++p)
      acc += Brow[p] * srow[p]; // Contract x direction
  }
  V = acc;
  __syncthreads();
}
259c532df63SYohann 
260ab213215SJeremy L Thompson //------------------------------------------------------------------------------
261ab213215SJeremy L Thompson // 2D tensor contraction y
262ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// 2D contraction along y. Each thread publishes U into its z-slab of slice.
// After a barrier, threads with tidy < Q1D compute
// V = sum_p B[p + tidy*P1D] * plane[tidx, p]. Other threads get V = 0.
// Both barriers are reached unconditionally by every thread.
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  CeedScalar *plane = slice + tidz*T1D*T1D;
  plane[tidx + tidy*T1D] = U;
  __syncthreads();
  CeedScalar acc = 0.0;
  if (tidy < Q1D) {
    const CeedScalar *Brow = B + tidy*P1D;
    for (int p = 0; p < P1D; ++p)
      acc += Brow[p] * plane[tidx + p*T1D]; // Contract y direction
  }
  V = acc;
  __syncthreads();
}
275c532df63SYohann 
276ab213215SJeremy L Thompson //------------------------------------------------------------------------------
277ab213215SJeremy L Thompson // 2D transpose tensor contraction y
278ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// 2D transpose contraction along y. Each thread publishes U into its z-slab
// of slice. After a barrier, threads with tidy < P1D compute
// V = sum_q B[tidy + q*P1D] * plane[tidx, q], i.e. B transposed applied in y.
// Other threads get V = 0. Both barriers are unconditional.
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  CeedScalar *plane = slice + tidz*T1D*T1D;
  plane[tidx + tidy*T1D] = U;
  __syncthreads();
  CeedScalar acc = 0.0;
  if (tidy < P1D) {
    for (int q = 0; q < Q1D; ++q)
      acc += B[tidy + q*P1D] * plane[tidx + q*T1D]; // Contract y direction
  }
  V = acc;
  __syncthreads();
}
290c532df63SYohann 
291ab213215SJeremy L Thompson //------------------------------------------------------------------------------
292ab213215SJeremy L Thompson // 2D transpose tensor contraction x
293ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// 2D transpose contraction along x. Each thread publishes U into its z-slab
// of slice. After a barrier, threads with tidx < P1D compute
// V = sum_q B[tidx + q*P1D] * plane[q, tidy], i.e. B transposed applied in x.
// Other threads get V = 0. Both barriers are unconditional.
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  CeedScalar *plane = slice + tidz*T1D*T1D;
  plane[tidx + tidy*T1D] = U;
  __syncthreads();
  CeedScalar acc = 0.0;
  if (tidx < P1D) {
    const CeedScalar *srow = plane + tidy*T1D;
    for (int q = 0; q < Q1D; ++q)
      acc += B[tidx + q*P1D] * srow[q]; // Contract x direction
  }
  V = acc;
  __syncthreads();
}
305c532df63SYohann 
306ab213215SJeremy L Thompson //------------------------------------------------------------------------------
307ab213215SJeremy L Thompson // 2D interpolate to quadrature points
308ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// 2D interpolation. Thread z index tidz = blockElem*BASIS_NCOMP + comp, so a
// block handles blockDim.z/BASIS_NCOMP elements. Each z-slab of threads owns
// one component of one element and its own plane of slice.
//   transpose == 0: nodal values -> quadrature values, c_B in x then y.
//   transpose != 0: quadrature values -> nodal values, c_B transposed in y then x.
// The contractions call __syncthreads(), so all threads of the block must
// make the same sequence of calls.
// NOTE(review): when nelem is not a multiple of elemsPerBlock, z-slabs of
// the last block can leave the element loop while others still reach the
// barriers. Confirm this is safe on the targeted architectures.
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  // The component is fixed per thread. It is computed once here; the loop
  // body no longer redeclares and shadows it. r_V and r_t need no reset
  // because every path assigns them before reading them.
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if (!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
    } else {
      readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
342c532df63SYohann 
343ab213215SJeremy L Thompson //------------------------------------------------------------------------------
344ab213215SJeremy L Thompson // 2D derivatives at quadrature points
345ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// 2D gradient with the same thread layout as interp2d
// (tidz = blockElem*BASIS_NCOMP + comp).
//   transpose == 0: dim 0 = (G in x, B in y), dim 1 = (B in x, G in y), each
//                   written as a separate quadrature field.
//   transpose != 0: both dims are read back, contracted with the transposed
//                   matrices, summed, and written as nodal values.
// The contractions call __syncthreads(), so all threads must make the same
// sequence of calls. NOTE(review): the loop trip count can differ between
// z-slabs of the last block. Confirm that the barriers are safe then.
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V, CeedScalar *slice) {
  const int tx = threadIdx.x;
  const int ty = threadIdx.y;
  const int tz = threadIdx.z;
  const int comp = tz%BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  CeedScalar r_in;
  CeedScalar r_out;
  CeedScalar r_tmp;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + tz/BASIS_NCOMP; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if (transpose) {
      // x-derivative contribution: B transposed in y, G transposed in x
      readQuads2d(elem, tx, ty, comp, 0, nelem, d_U, r_in);
      ContractTransposeY2d(slice, tx, ty, tz, r_in, c_B, r_tmp);
      ContractTransposeX2d(slice, tx, ty, tz, r_tmp, c_G, r_out);
      // y-derivative contribution: G transposed in y, B transposed in x
      readQuads2d(elem, tx, ty, comp, 1, nelem, d_U, r_in);
      ContractTransposeY2d(slice, tx, ty, tz, r_in, c_G, r_tmp);
      ContractTransposeX2d(slice, tx, ty, tz, r_tmp, c_B, r_in);
      r_out += r_in;
      writeDofs2d(elem, tx, ty, comp, nelem, r_out, d_V);
    } else {
      readDofs2d(elem, tx, ty, comp, nelem, d_U, r_in);
      // x-derivative: G in x, B in y
      ContractX2d(slice, tx, ty, tz, r_in, c_G, r_tmp);
      ContractY2d(slice, tx, ty, tz, r_tmp, c_B, r_out);
      writeQuads2d(elem, tx, ty, comp, 0, nelem, r_out, d_V);
      // y-derivative: B in x, G in y
      ContractX2d(slice, tx, ty, tz, r_in, c_B, r_tmp);
      ContractY2d(slice, tx, ty, tz, r_tmp, c_G, r_out);
      writeQuads2d(elem, tx, ty, comp, 1, nelem, r_out, d_V);
    }
  }
}
388c532df63SYohann 
389ab213215SJeremy L Thompson //------------------------------------------------------------------------------
390ab213215SJeremy L Thompson // 2D quadrature weights
391ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Fill w with tensor-product 2D quadrature weights for every element:
// w[e*Q1D*Q1D + qx + qy*Q1D] = qweight1d[qx]*qweight1d[qy], where
// (qx, qy) = (threadIdx.x, threadIdx.y). Elements are distributed over
// threadIdx.z / blockDim.z and the grid. qweight1d is indexed without a
// bound check, so this presumably expects blockDim.x == blockDim.y == Q1D.
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const int local = qx + qy*Q1D;
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy];
  for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < nelem;
       e += gridDim.x*blockDim.z)
    w[e*Q1D*Q1D + local] = wq;
}
403ab213215SJeremy L Thompson 
404ab213215SJeremy L Thompson //------------------------------------------------------------------------------
405ab213215SJeremy L Thompson // 3D
406ab213215SJeremy L Thompson //------------------------------------------------------------------------------
407ab213215SJeremy L Thompson 
408ab213215SJeremy L Thompson //------------------------------------------------------------------------------
409ab213215SJeremy L Thompson // Read DoFs
410ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Load the column of P1D nodal values along z at node (tidx, tidy) of
// component comp of element elem into r_U[0..P1D-1]. d_U is indexed as
// [comp][elem][P1D][P1D][P1D]. Threads outside the P1D x P1D node grid get
// zeros. r_U[P1D..Q1D-1] is zero-filled; r_U presumably holds at least
// max(P1D, Q1D) entries. Confirm at the call site.
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *r_U) {
  const bool inside = tidx < P1D && tidy < P1D;
  const int base = comp*P1D*P1D*P1D*nelem + elem*P1D*P1D*P1D + tidy*P1D + tidx;
  for (int k = 0; k < P1D; ++k) {
    CeedScalar val = 0.0;
    if (inside)
      val = d_U[base + k*P1D*P1D];
    r_U[k] = val;
  }
  for (int k = P1D; k < Q1D; ++k)
    r_U[k] = 0.0;
}
422c532df63SYohann 
423ab213215SJeremy L Thompson //------------------------------------------------------------------------------
42449fd234cSJeremy L Thompson // Write DoFs
42549fd234cSJeremy L Thompson //------------------------------------------------------------------------------
// Store the column r_V[0..P1D-1] along z at node (tidx, tidy) of component
// comp of element elem. d_V is indexed as [comp][elem][P1D][P1D][P1D].
// Threads outside the P1D x P1D node grid store nothing.
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D)
    return;
  CeedScalar *column = d_V + (comp*P1D*P1D*P1D*nelem + elem*P1D*P1D*P1D +
                              tidy*P1D + tidx);
  for (int k = 0; k < P1D; ++k)
    column[k*P1D*P1D] = r_V[k];
}
43649fd234cSJeremy L Thompson 
43749fd234cSJeremy L Thompson //------------------------------------------------------------------------------
438ab213215SJeremy L Thompson // Read quadrature point data
439ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Load the column of Q1D quadrature values along z at point (tidx, tidy) of
// (dim, comp, elem) into r_U[0..Q1D-1]. d_U is indexed as
// [dim][comp][elem][Q1D][Q1D][Q1D]. Threads outside the Q1D x Q1D point
// grid get zeros. r_U[Q1D..P1D-1] is zero-filled.
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *r_U) {
  const bool inside = tidx < Q1D && tidy < Q1D;
  const int base = dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
                   elem*Q1D*Q1D*Q1D + tidy*Q1D + tidx;
  for (int k = 0; k < Q1D; ++k) {
    CeedScalar val = 0.0;
    if (inside)
      val = d_U[base + k*Q1D*Q1D];
    r_U[k] = val;
  }
  for (int k = Q1D; k < P1D; ++k)
    r_U[k] = 0.0;
}
451c532df63SYohann 
452ab213215SJeremy L Thompson //------------------------------------------------------------------------------
453ab213215SJeremy L Thompson // Write quadrature point data
454ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Store the column r_V[0..Q1D-1] along z at quadrature point (tidx, tidy) of
// (dim, comp, elem). d_V is indexed as [dim][comp][elem][Q1D][Q1D][Q1D].
// Threads outside the Q1D x Q1D point grid store nothing.
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar *r_V, CeedScalar *d_V) {
  if (tidx >= Q1D || tidy >= Q1D)
    return;
  CeedScalar *column = d_V + (dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D +
                              comp*Q1D*Q1D*Q1D*nelem + elem*Q1D*Q1D*Q1D +
                              tidy*Q1D + tidx);
  for (int k = 0; k < Q1D; ++k)
    column[k*Q1D*Q1D] = r_V[k];
}
465c532df63SYohann 
466ab213215SJeremy L Thompson //------------------------------------------------------------------------------
467ab213215SJeremy L Thompson // 3D tensor contract x
468ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// 3D contraction along x, one z-layer at a time for layers k = 0..P1D-1.
// For each layer, every thread publishes U[k] into its z-slab plane of
// slice (offset tidz*T1D*T1D). After a barrier, threads with tidx < Q1D and
// tidy < P1D compute V[k] = sum_p B[p + tidx*P1D] * plane[p, tidy]. Other
// threads get V[k] = 0. Both barriers per layer are reached by every
// thread, since the loop bound P1D is a compile-time constant.
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  CeedScalar *plane = slice + tidz*T1D*T1D;
  const bool active = tidx < Q1D && tidy < P1D;
  for (int k = 0; k < P1D; ++k) {
    plane[tidx + tidy*T1D] = U[k];
    __syncthreads();
    CeedScalar acc = 0.0;
    if (active)
      for (int p = 0; p < P1D; ++p)
        acc += B[p + tidx*P1D] * plane[p + tidy*T1D]; // Contract x direction
    V[k] = acc;
    __syncthreads();
  }
}
484c532df63SYohann 
485ab213215SJeremy L Thompson //------------------------------------------------------------------------------
486ab213215SJeremy L Thompson // 3D tensor contract y
487ab213215SJeremy L Thompson //------------------------------------------------------------------------------
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  // Apply B (indexed B[p + q*P1D]) along y to each of the P1D register
  //   layers in U, staging one layer at a time through shared memory.
  // Contains __syncthreads: every thread of the block must call this.
  const int plane = tidz*T1D*T1D;
  const bool active = tidx < Q1D && tidy < Q1D;
  for (int layer = 0; layer < P1D; ++layer) {
    slice[tidx + tidy*T1D + plane] = U[layer];
    __syncthreads();
    CeedScalar sum = 0.0;
    if (active) {
      for (int p = 0; p < P1D; ++p)
        sum += B[p + tidy*P1D] * slice[tidx + p*T1D + plane];
    }
    V[layer] = sum;
    __syncthreads();
  }
}
503c532df63SYohann 
504ab213215SJeremy L Thompson //------------------------------------------------------------------------------
505ab213215SJeremy L Thompson // 3D tensor contract z
506ab213215SJeremy L Thompson //------------------------------------------------------------------------------
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  // Apply B (indexed B[p + q*P1D]) along z. The z-direction lives in the
  //   register arrays U and V, so no shared memory or barrier is needed;
  //   slice and tidz are unused and kept for a uniform signature.
  const bool active = tidx < Q1D && tidy < Q1D;
  for (int q = 0; q < Q1D; ++q) {
    CeedScalar sum = 0.0;
    if (active) {
      for (int p = 0; p < P1D; ++p)
        sum += B[p + q*P1D] * U[p];
    }
    V[q] = sum;
  }
  // Clear leftover register slots (only present when P1D > Q1D).
  for (int q = Q1D; q < P1D; ++q)
    V[q] = 0.0;
}
521c532df63SYohann 
522ab213215SJeremy L Thompson //------------------------------------------------------------------------------
523ab213215SJeremy L Thompson // 3D transpose tensor contract z
524ab213215SJeremy L Thompson //------------------------------------------------------------------------------
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  // Apply the transpose of B (indexed B[p + q*P1D]) along z, entirely in
  //   registers; slice and tidz are unused.
  const bool active = tidx < Q1D && tidy < Q1D;
  for (int p = 0; p < P1D; ++p) {
    CeedScalar sum = 0.0;
    if (active) {
      for (int q = 0; q < Q1D; ++q)
        sum += B[p + q*P1D] * U[q];
    }
    V[p] = sum;
  }
  // Clear leftover register slots (only present when Q1D > P1D).
  for (int p = P1D; p < Q1D; ++p)
    V[p] = 0.0;
}
539c532df63SYohann 
540ab213215SJeremy L Thompson //------------------------------------------------------------------------------
541ab213215SJeremy L Thompson // 3D transpose tensor contract y
542ab213215SJeremy L Thompson //------------------------------------------------------------------------------
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  // Apply the transpose of B (indexed B[p + q*P1D]) along y to each of the
  //   P1D register layers in U, staging through shared memory.
  // Contains __syncthreads: every thread of the block must call this.
  const int plane = tidz*T1D*T1D;
  const bool active = tidx < Q1D && tidy < P1D;
  for (int layer = 0; layer < P1D; ++layer) {
    slice[tidx + tidy*T1D + plane] = U[layer];
    __syncthreads();
    CeedScalar sum = 0.0;
    if (active) {
      for (int q = 0; q < Q1D; ++q)
        sum += B[tidy + q*P1D] * slice[tidx + q*T1D + plane];
    }
    V[layer] = sum;
    __syncthreads();
  }
}
558c532df63SYohann 
559ab213215SJeremy L Thompson //------------------------------------------------------------------------------
560ab213215SJeremy L Thompson // 3D transpose tensor contract x
561ab213215SJeremy L Thompson //------------------------------------------------------------------------------
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  // Apply the transpose of B (indexed B[p + q*P1D]) along x to each of the
  //   P1D register layers in U, staging through shared memory.
  // Contains __syncthreads: every thread of the block must call this.
  const int plane = tidz*T1D*T1D;
  const bool active = tidx < P1D && tidy < P1D;
  for (int layer = 0; layer < P1D; ++layer) {
    slice[tidx + tidy*T1D + plane] = U[layer];
    __syncthreads();
    CeedScalar sum = 0.0;
    if (active) {
      for (int q = 0; q < Q1D; ++q)
        sum += B[tidx + q*P1D] * slice[q + tidy*T1D + plane];
    }
    V[layer] = sum;
    __syncthreads();
  }
}
577c532df63SYohann 
578ab213215SJeremy L Thompson //------------------------------------------------------------------------------
579ab213215SJeremy L Thompson // 3D interpolate to quadrature points
580ab213215SJeremy L Thompson //------------------------------------------------------------------------------
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  // 3D tensor-product interpolation. Each thread owns one (x, y) column of
  //   one component of one element and keeps its z-values in registers;
  //   two register buffers alternate as input and output of the x, y and z
  //   contractions.
  // The Contract*3d helpers call __syncthreads, so every thread of a block
  //   must run the same number of element-loop iterations. The 3D host
  //   launch in this file uses one element per block, which satisfies this.
  CeedScalar r_a[T1D];
  CeedScalar r_b[T1D];

  const int tx = threadIdx.x;
  const int ty = threadIdx.y;
  const int tz = threadIdx.z;
  // threadIdx.z enumerates (element-in-block, component) pairs.
  const int comp = tz%BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int firstElem = blockIdx.x*elemsPerBlock + tz/BASIS_NCOMP;
  const CeedInt elemStride = gridDim.x*elemsPerBlock;

  for (CeedInt elem = firstElem; elem < nelem; elem += elemStride) {
    for (int i = 0; i < T1D; ++i) {
      r_a[i] = 0.0;
      r_b[i] = 0.0;
    }
    if (transpose) {
      // Quadrature points to nodes: B^T along z, then y, then x.
      readQuads3d(elem, tx, ty, comp, 0, nelem, d_U, r_a);
      ContractTransposeZ3d(slice, tx, ty, tz, r_a, c_B, r_b);
      ContractTransposeY3d(slice, tx, ty, tz, r_b, c_B, r_a);
      ContractTransposeX3d(slice, tx, ty, tz, r_a, c_B, r_b);
      writeDofs3d(elem, tx, ty, comp, nelem, r_b, d_V);
    } else {
      // Nodes to quadrature points: B along x, then y, then z.
      readDofs3d(elem, tx, ty, comp, nelem, d_U, r_a);
      ContractX3d(slice, tx, ty, tz, r_a, c_B, r_b);
      ContractY3d(slice, tx, ty, tz, r_b, c_B, r_a);
      ContractZ3d(slice, tx, ty, tz, r_a, c_B, r_b);
      writeQuads3d(elem, tx, ty, comp, 0, nelem, r_b, d_V);
    }
  }
}
617c532df63SYohann 
618ab213215SJeremy L Thompson //------------------------------------------------------------------------------
619ab213215SJeremy L Thompson // 3D derivatives at quadrature points
620ab213215SJeremy L Thompson //------------------------------------------------------------------------------
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // 3D tensor-product gradient. Derivative component d uses the gradient
  //   matrix c_G along direction d and the interpolation matrix c_B along
  //   the other two. The forward path writes one quadrature block per d.
  //   The transpose path reads the three blocks and sums their contributions
  //   into one nodal result.
  // The Contract*3d helpers call __syncthreads, so every thread of a block
  //   must run the same number of element-loop iterations. The 3D host
  //   launch in this file uses one element per block.
  CeedScalar r_in[T1D];
  CeedScalar r_out[T1D];
  CeedScalar r_tmp[T1D];

  const int tx = threadIdx.x;
  const int ty = threadIdx.y;
  const int tz = threadIdx.z;
  // threadIdx.z enumerates (element-in-block, component) pairs.
  const int comp = tz%BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int firstElem = blockIdx.x*elemsPerBlock + tz/BASIS_NCOMP;
  const CeedInt elemStride = gridDim.x*elemsPerBlock;

  for (CeedInt elem = firstElem; elem < nelem; elem += elemStride) {
    for (int i = 0; i < T1D; ++i) {
      r_in[i] = 0.0;
      r_out[i] = 0.0;
      r_tmp[i] = 0.0;
    }
    if (transpose) {
      // Component 0 (G along x) seeds the nodal result in r_out.
      readQuads3d(elem, tx, ty, comp, 0, nelem, d_U, r_in);
      ContractTransposeZ3d(slice, tx, ty, tz, r_in, c_B, r_tmp);
      ContractTransposeY3d(slice, tx, ty, tz, r_tmp, c_B, r_in);
      ContractTransposeX3d(slice, tx, ty, tz, r_in, c_G, r_out);
      // Component 1 (G along y) is accumulated into r_out.
      readQuads3d(elem, tx, ty, comp, 1, nelem, d_U, r_in);
      ContractTransposeZ3d(slice, tx, ty, tz, r_in, c_B, r_tmp);
      ContractTransposeY3d(slice, tx, ty, tz, r_tmp, c_G, r_in);
      ContractTransposeX3d(slice, tx, ty, tz, r_in, c_B, r_tmp);
      add(r_out, r_tmp);
      // Component 2 (G along z) is accumulated into r_out.
      readQuads3d(elem, tx, ty, comp, 2, nelem, d_U, r_in);
      ContractTransposeZ3d(slice, tx, ty, tz, r_in, c_G, r_tmp);
      ContractTransposeY3d(slice, tx, ty, tz, r_tmp, c_B, r_in);
      ContractTransposeX3d(slice, tx, ty, tz, r_in, c_B, r_tmp);
      add(r_out, r_tmp);
      writeDofs3d(elem, tx, ty, comp, nelem, r_out, d_V);
    } else {
      // Nodal values are loaded once and reused for all three components.
      readDofs3d(elem, tx, ty, comp, nelem, d_U, r_in);
      // Component 0: G along x.
      ContractX3d(slice, tx, ty, tz, r_in, c_G, r_out);
      ContractY3d(slice, tx, ty, tz, r_out, c_B, r_tmp);
      ContractZ3d(slice, tx, ty, tz, r_tmp, c_B, r_out);
      writeQuads3d(elem, tx, ty, comp, 0, nelem, r_out, d_V);
      // Component 1: G along y.
      ContractX3d(slice, tx, ty, tz, r_in, c_B, r_out);
      ContractY3d(slice, tx, ty, tz, r_out, c_G, r_tmp);
      ContractZ3d(slice, tx, ty, tz, r_tmp, c_B, r_out);
      writeQuads3d(elem, tx, ty, comp, 1, nelem, r_out, d_V);
      // Component 2: G along z.
      ContractX3d(slice, tx, ty, tz, r_in, c_B, r_out);
      ContractY3d(slice, tx, ty, tz, r_out, c_B, r_tmp);
      ContractZ3d(slice, tx, ty, tz, r_tmp, c_G, r_out);
      writeQuads3d(elem, tx, ty, comp, 2, nelem, r_out, d_V);
    }
  }
}
685c532df63SYohann 
686ab213215SJeremy L Thompson //------------------------------------------------------------------------------
687ab213215SJeremy L Thompson // 3D quadrature weights
688ab213215SJeremy L Thompson //------------------------------------------------------------------------------
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  // One thread per quadrature point (x, y, z). The tensor-product weight
  //   does not depend on the element, so it is computed once and then
  //   stored for every element this block visits (blocks stride by
  //   gridDim.x over elements).
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const int qz = threadIdx.z;
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy]*qweight1d[qz];
  const int local = qx + qy*Q1D + qz*Q1D*Q1D;
  for (int e = blockIdx.x; e < nelem; e += gridDim.x)
    w[local + e*Q1D*Q1D*Q1D] = wq;
}
700ab213215SJeremy L Thompson 
701ab213215SJeremy L Thompson //------------------------------------------------------------------------------
702ab213215SJeremy L Thompson // Basis kernels
703ab213215SJeremy L Thompson //------------------------------------------------------------------------------
704ab213215SJeremy L Thompson 
705ab213215SJeremy L Thompson //------------------------------------------------------------------------------
706ab213215SJeremy L Thompson // Interp kernel by dim
707ab213215SJeremy L Thompson //------------------------------------------------------------------------------
extern "C" __global__ void interp(const CeedInt nelem, const int transpose,
                                  const CeedScalar *c_B,
                                  const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  // Interpolation entry point; dispatches on the compile-time BASIS_DIM.
  //   A nonzero transpose maps quadrature-point data back to nodes.
  // The dynamic shared-memory scratch is sized by the host launcher in
  //   units of sizeof(CeedScalar) and handed to helpers that take a
  //   CeedScalar pointer, so it is declared with that element type.
  extern __shared__ CeedScalar slice[];
  if (BASIS_DIM == 1) {
    interp1d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM == 2) {
    interp2d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM == 3) {
    interp3d(nelem, transpose, c_B, d_U, d_V, slice);
  }
}
721c532df63SYohann 
722ab213215SJeremy L Thompson //------------------------------------------------------------------------------
723ab213215SJeremy L Thompson // Grad kernel by dim
724ab213215SJeremy L Thompson //------------------------------------------------------------------------------
extern "C" __global__ void grad(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *c_G,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V) {
  // Gradient entry point; dispatches on the compile-time BASIS_DIM.
  //   c_B is the 1D interpolation matrix and c_G the 1D gradient matrix.
  // The dynamic shared-memory scratch is sized by the host launcher in
  //   units of sizeof(CeedScalar) and handed to helpers that take a
  //   CeedScalar pointer, so it is declared with that element type.
  extern __shared__ CeedScalar slice[];
  if (BASIS_DIM == 1) {
    grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM == 2) {
    grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM == 3) {
    grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  }
}
738c532df63SYohann 
739ab213215SJeremy L Thompson //------------------------------------------------------------------------------
740ab213215SJeremy L Thompson // Weight kernels by dim
741ab213215SJeremy L Thompson //------------------------------------------------------------------------------
extern "C" __global__ void weight(const CeedInt nelem,
                                  const CeedScalar *__restrict__ qweight1d,
                                  CeedScalar *__restrict__ v) {
  // Quadrature-weight entry point; selects the helper matching BASIS_DIM.
  //   No shared memory is used.
  switch (BASIS_DIM) {
  case 1:
    weight1d(nelem, qweight1d, v);
    break;
  case 2:
    weight2d(nelem, qweight1d, v);
    break;
  case 3:
    weight3d(nelem, qweight1d, v);
    break;
  }
}
753c532df63SYohann 
754c532df63SYohann );
755cb0b5415Sjeremylt // *INDENT-ON*
756c532df63SYohann 
757ab213215SJeremy L Thompson //------------------------------------------------------------------------------
758ab213215SJeremy L Thompson // Device initialization
759ab213215SJeremy L Thompson //------------------------------------------------------------------------------
// Prepare the 1D interpolation matrix d_B (P1d nodes, Q1d quadrature points)
//   for the basis kernels. The pointer returned through c_B is passed to the
//   interp kernel as its c_B argument. The definition is not visible here;
//   presumably it copies d_B into device constant memory (c_ prefix).
//   NOTE(review): confirm against the definition.
760c532df63SYohann int CeedCudaInitInterp(CeedScalar *d_B, CeedInt P1d, CeedInt Q1d,
761c532df63SYohann                        CeedScalar **c_B);
// Same as CeedCudaInitInterp, for both the interpolation (d_B) and gradient
//   (d_G) matrices. The pointers returned through c_B_ptr and c_G_ptr are
//   passed to the grad kernel as its c_B and c_G arguments.
//   NOTE(review): confirm against the definition.
762c532df63SYohann int CeedCudaInitInterpGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P1d,
7637f823360Sjeremylt                            CeedInt Q1d, CeedScalar **c_B_ptr,
7647f823360Sjeremylt                            CeedScalar **c_G_ptr);
765c532df63SYohann 
766ab213215SJeremy L Thompson //------------------------------------------------------------------------------
767ab213215SJeremy L Thompson // Apply basis
768ab213215SJeremy L Thompson //------------------------------------------------------------------------------
int CeedBasisApplyTensor_Cuda_shared(CeedBasis basis, const CeedInt nelem,
                                     CeedTransposeMode tmode,
                                     CeedEvalMode emode, CeedVector u,
                                     CeedVector v) {
  // Apply the tensor-product basis action selected by emode to nelem
  //   elements. u is read (device memory) except for CEED_EVAL_WEIGHT, and
  //   the result is written to v (device memory). For CEED_TRANSPOSE, v is
  //   zeroed before the kernel runs.
  // Returns CEED_ERROR_SUCCESS, or a backend error code for failed calls and
  //   for the unsupported modes CEED_EVAL_DIV, CEED_EVAL_CURL, CEED_EVAL_NONE.
  int ierr;
  Ceed ceed;
  ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
  CeedBasis_Cuda_shared *data;
  ierr = CeedBasisGetData(basis, &data); CeedChkBackend(ierr);
  const CeedInt transpose = tmode == CEED_TRANSPOSE;
  CeedInt dim, ncomp;
  ierr = CeedBasisGetDimension(basis, &dim); CeedChkBackend(ierr);
  ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChkBackend(ierr);

  // Read vectors
  const CeedScalar *d_u;
  CeedScalar *d_v;
  if (emode != CEED_EVAL_WEIGHT) {
    ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChkBackend(ierr);
  }
  ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChkBackend(ierr);

  // Clear v for transpose mode
  if (tmode == CEED_TRANSPOSE) {
    CeedInt length;
    ierr = CeedVectorGetLength(v, &length); CeedChkBackend(ierr);
    // cudaMemset returns a cudaError_t; check it like the other CUDA calls.
    ierr = cudaMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChk_Cu(ceed, ierr);
  }

  // Apply basis operation
  switch (emode) {
  case CEED_EVAL_INTERP: {
    CeedInt P1d, Q1d;
    ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChkBackend(ierr);
    ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr);
    // Threads per direction cover both nodes and quadrature points.
    CeedInt thread1d = CeedIntMax(Q1d, P1d);
    ierr = CeedCudaInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B);
    CeedChkBackend(ierr);
    void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
                          &d_u, &d_v
                         };
    if (dim == 1) {
      CeedInt elemsPerBlock = 32;
      CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
                                             ? 1 : 0 );
      CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
      ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, thread1d, 1,
                                        elemsPerBlock, sharedMem,
                                        interpargs); CeedChkBackend(ierr);
    } else if (dim == 2) {
      const CeedInt optElems[7] = {0,32,8,6,4,2,8};
      // elemsPerBlock must be at least 1
      CeedInt elemsPerBlock = CeedIntMax(thread1d<7?optElems[thread1d]/ncomp:1, 1);
      CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
                                             ? 1 : 0 );
      CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
      ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, thread1d, thread1d,
                                        ncomp*elemsPerBlock, sharedMem,
                                        interpargs); CeedChkBackend(ierr);
    } else if (dim == 3) {
      // One element per block: the 3D kernel barriers need every thread of
      //   a block to process the same number of elements.
      CeedInt elemsPerBlock = 1;
      CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
                                             ? 1 : 0 );
      CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
      ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, thread1d, thread1d,
                                        ncomp*elemsPerBlock, sharedMem,
                                        interpargs); CeedChkBackend(ierr);
    }
  } break;
  case CEED_EVAL_GRAD: {
    CeedInt P1d, Q1d;
    ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChkBackend(ierr);
    ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr);
    CeedInt thread1d = CeedIntMax(Q1d, P1d);
    ierr = CeedCudaInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d,
                                  Q1d, &data->c_B, &data->c_G);
    CeedChkBackend(ierr);
    void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
                        &data->c_G, &d_u, &d_v
                       };
    if (dim == 1) {
      CeedInt elemsPerBlock = 32;
      CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
                                             ? 1 : 0 );
      CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
      ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, thread1d, 1,
                                        elemsPerBlock, sharedMem, gradargs);
      CeedChkBackend(ierr);
    } else if (dim == 2) {
      const CeedInt optElems[7] = {0,32,8,6,4,2,8};
      // elemsPerBlock must be at least 1
      CeedInt elemsPerBlock = CeedIntMax(thread1d<7?optElems[thread1d]/ncomp:1, 1);
      CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
                                             ? 1 : 0 );
      CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
      ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, thread1d, thread1d,
                                        ncomp*elemsPerBlock, sharedMem,
                                        gradargs); CeedChkBackend(ierr);
    } else if (dim == 3) {
      // One element per block (see the interp case above for why).
      CeedInt elemsPerBlock = 1;
      CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
                                             ? 1 : 0 );
      CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
      ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, thread1d, thread1d,
                                        ncomp*elemsPerBlock, sharedMem,
                                        gradargs); CeedChkBackend(ierr);
    }
  } break;
  case CEED_EVAL_WEIGHT: {
    CeedInt Q1d;
    ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr);
    void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
    if (dim == 1) {
      // 32/Q1d is 0 when Q1d > 32; keep at least one element per block so
      //   the grid-size division below is well defined (matches dim == 2).
      const CeedInt optElems = 32/Q1d;
      const CeedInt elemsPerBlock = optElems>0?optElems:1;
      const CeedInt gridsize = nelem/elemsPerBlock + ( (
                                 nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
      ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d,
                                  elemsPerBlock, 1, weightargs);
      CeedChkBackend(ierr);
    } else if (dim == 2) {
      const CeedInt optElems = 32/(Q1d*Q1d);
      const CeedInt elemsPerBlock = optElems>0?optElems:1;
      const CeedInt gridsize = nelem/elemsPerBlock + ( (
                                 nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
      ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d,
                                  elemsPerBlock, weightargs);
      CeedChkBackend(ierr);
    } else if (dim == 3) {
      const CeedInt gridsize = nelem;
      ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d, Q1d,
                                  weightargs);
      CeedChkBackend(ierr);
    }
  } break;
  // LCOV_EXCL_START
  // NOTE(review): the error returns below leave u and v checked out
  //   (no Restore call); acceptable only because these paths are errors.
  // Evaluate the divergence to/from the quadrature points
  case CEED_EVAL_DIV:
    return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported");
  // Evaluate the curl to/from the quadrature points
  case CEED_EVAL_CURL:
    return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported");
  // Take no action, BasisApply should not have been called
  case CEED_EVAL_NONE:
    return CeedError(ceed, CEED_ERROR_BACKEND,
                     "CEED_EVAL_NONE does not make sense in this context");
    // LCOV_EXCL_STOP
  }

  // Restore vectors
  if (emode != CEED_EVAL_WEIGHT) {
    ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChkBackend(ierr);
  }
  ierr = CeedVectorRestoreArray(v, &d_v); CeedChkBackend(ierr);
  return CEED_ERROR_SUCCESS;
}
926c532df63SYohann 
927ab213215SJeremy L Thompson //------------------------------------------------------------------------------
928ab213215SJeremy L Thompson // Destroy basis
929ab213215SJeremy L Thompson //------------------------------------------------------------------------------
static int CeedBasisDestroy_Cuda_shared(CeedBasis basis) {
  // Tear down a shared-memory CUDA basis: unload its CUDA module, free the
  //   device-side 1D basis arrays, then free the host-side data struct.
  Ceed ceed;
  CeedBasis_Cuda_shared *data;
  int ierr = CeedBasisGetCeed(basis, &ceed);
  CeedChkBackend(ierr);
  ierr = CeedBasisGetData(basis, &data);
  CeedChkBackend(ierr);

  CeedChk_Cu(ceed, cuModuleUnload(data->module));

  // Device arrays released one after another; the first failure is reported.
  void *d_buffers[] = {data->d_qweight1d, data->d_interp1d, data->d_grad1d,
                       data->d_collograd1d
                      };
  const size_t num_buffers = sizeof(d_buffers) / sizeof(d_buffers[0]);
  for (size_t b = 0; b < num_buffers; b++) {
    ierr = cudaFree(d_buffers[b]);
    CeedChk_Cu(ceed, ierr);
  }

  ierr = CeedFree(&data);
  CeedChkBackend(ierr);
  return CEED_ERROR_SUCCESS;
}
949c532df63SYohann 
950ab213215SJeremy L Thompson //------------------------------------------------------------------------------
951ab213215SJeremy L Thompson // Create tensor basis
952ab213215SJeremy L Thompson //------------------------------------------------------------------------------
953c532df63SYohann int CeedBasisCreateTensorH1_Cuda_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
954c532df63SYohann                                         const CeedScalar *interp1d,
955c532df63SYohann                                         const CeedScalar *grad1d,
956c532df63SYohann                                         const CeedScalar *qref1d,
957c532df63SYohann                                         const CeedScalar *qweight1d,
958c532df63SYohann                                         CeedBasis basis) {
959c532df63SYohann   int ierr;
960c532df63SYohann   Ceed ceed;
961e15f9bd0SJeremy L Thompson   ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
962c532df63SYohann   CeedBasis_Cuda_shared *data;
963e15f9bd0SJeremy L Thompson   ierr = CeedCalloc(1, &data); CeedChkBackend(ierr);
964c532df63SYohann 
965ab213215SJeremy L Thompson   // Copy basis data to GPU
966c532df63SYohann   const CeedInt qBytes = Q1d * sizeof(CeedScalar);
967c532df63SYohann   ierr = cudaMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Cu(ceed, ierr);
968c532df63SYohann   ierr = cudaMemcpy(data->d_qweight1d, qweight1d, qBytes,
969c532df63SYohann                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
970c532df63SYohann 
971c532df63SYohann   const CeedInt iBytes = qBytes * P1d;
972c532df63SYohann   ierr = cudaMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Cu(ceed, ierr);
973c532df63SYohann   ierr = cudaMemcpy(data->d_interp1d, interp1d, iBytes,
974c532df63SYohann                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
975c532df63SYohann 
976c532df63SYohann   ierr = cudaMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Cu(ceed, ierr);
977c532df63SYohann   ierr = cudaMemcpy(data->d_grad1d, grad1d, iBytes,
978c532df63SYohann                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
979c532df63SYohann 
980ab213215SJeremy L Thompson   // Compute collocated gradient and copy to GPU
981ac421f39SYohann   data->d_collograd1d = NULL;
982ac421f39SYohann   if (dim == 3 && Q1d >= P1d) {
983ac421f39SYohann     CeedScalar *collograd1d;
984e15f9bd0SJeremy L Thompson     ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChkBackend(ierr);
985e15f9bd0SJeremy L Thompson     ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChkBackend(ierr);
986ac421f39SYohann     ierr = cudaMalloc((void **)&data->d_collograd1d, qBytes * Q1d);
987ac421f39SYohann     CeedChk_Cu(ceed, ierr);
988ac421f39SYohann     ierr = cudaMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d,
989ac421f39SYohann                       cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
990e15f9bd0SJeremy L Thompson     ierr = CeedFree(&collograd1d); CeedChkBackend(ierr);
991ac421f39SYohann   }
992ac421f39SYohann 
993ab213215SJeremy L Thompson   // Compile basis kernels
994c532df63SYohann   CeedInt ncomp;
995e15f9bd0SJeremy L Thompson   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChkBackend(ierr);
99618d499f1SYohann   ierr = CeedCompileCuda(ceed, kernelsShared, &data->module, 8,
997c532df63SYohann                          "Q1D", Q1d,
998c532df63SYohann                          "P1D", P1d,
99918d499f1SYohann                          "T1D", CeedIntMax(Q1d, P1d),
1000c532df63SYohann                          "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
1001c532df63SYohann                              Q1d : P1d, dim),
1002c532df63SYohann                          "BASIS_DIM", dim,
1003c532df63SYohann                          "BASIS_NCOMP", ncomp,
1004c532df63SYohann                          "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
1005c532df63SYohann                          "BASIS_NQPT", CeedIntPow(Q1d, dim)
1006e15f9bd0SJeremy L Thompson                         ); CeedChkBackend(ierr);
10074a6d4bbdSYohann Dudouit   ierr = CeedGetKernelCuda(ceed, data->module, "interp", &data->interp);
1008e15f9bd0SJeremy L Thompson   CeedChkBackend(ierr);
10094a6d4bbdSYohann Dudouit   ierr = CeedGetKernelCuda(ceed, data->module, "grad", &data->grad);
1010e15f9bd0SJeremy L Thompson   CeedChkBackend(ierr);
10114a6d4bbdSYohann Dudouit   ierr = CeedGetKernelCuda(ceed, data->module, "weight", &data->weight);
1012e15f9bd0SJeremy L Thompson   CeedChkBackend(ierr);
1013c532df63SYohann 
1014e15f9bd0SJeremy L Thompson   ierr = CeedBasisSetData(basis, data); CeedChkBackend(ierr);
1015ab213215SJeremy L Thompson 
1016ab213215SJeremy L Thompson   // Register backend functions
1017c532df63SYohann   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
1018c532df63SYohann                                 CeedBasisApplyTensor_Cuda_shared);
1019e15f9bd0SJeremy L Thompson   CeedChkBackend(ierr);
1020c532df63SYohann   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
1021e15f9bd0SJeremy L Thompson                                 CeedBasisDestroy_Cuda_shared); CeedChkBackend(ierr);
1022e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
1023c532df63SYohann }
1024ab213215SJeremy L Thompson //------------------------------------------------------------------------------
1025