xref: /libCEED/rust/libceed-sys/c-src/backends/hip-shared/ceed-hip-shared-basis.c (revision 7d8d0e25636a94a27ff75b3dec09737e24cdb0fe)
1*7d8d0e25Snbeams // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2*7d8d0e25Snbeams // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3*7d8d0e25Snbeams // All Rights reserved. See files LICENSE and NOTICE for details.
4*7d8d0e25Snbeams //
5*7d8d0e25Snbeams // This file is part of CEED, a collection of benchmarks, miniapps, software
6*7d8d0e25Snbeams // libraries and APIs for efficient high-order finite element and spectral
7*7d8d0e25Snbeams // element discretizations for exascale applications. For more information and
8*7d8d0e25Snbeams // source code availability see http://github.com/ceed.
9*7d8d0e25Snbeams //
10*7d8d0e25Snbeams // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11*7d8d0e25Snbeams // a collaborative effort of two U.S. Department of Energy organizations (Office
12*7d8d0e25Snbeams // of Science and the National Nuclear Security Administration) responsible for
13*7d8d0e25Snbeams // the planning and preparation of a capable exascale ecosystem, including
14*7d8d0e25Snbeams // software, applications, hardware, advanced system engineering and early
15*7d8d0e25Snbeams // testbed platforms, in support of the nation's exascale computing imperative.
16*7d8d0e25Snbeams 
17*7d8d0e25Snbeams #include "ceed-hip-shared.h"
18*7d8d0e25Snbeams #include "../hip/ceed-hip-compile.h"
19*7d8d0e25Snbeams 
20*7d8d0e25Snbeams //------------------------------------------------------------------------------
21*7d8d0e25Snbeams // Shared mem kernels
22*7d8d0e25Snbeams //------------------------------------------------------------------------------
23*7d8d0e25Snbeams // *INDENT-OFF*
24*7d8d0e25Snbeams static const char *kernelsShared = QUOTE(
25*7d8d0e25Snbeams 
26*7d8d0e25Snbeams //------------------------------------------------------------------------------
27*7d8d0e25Snbeams // Sum input into output
28*7d8d0e25Snbeams //------------------------------------------------------------------------------
29*7d8d0e25Snbeams inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
30*7d8d0e25Snbeams   for (int i = 0; i < P1D; i++)
31*7d8d0e25Snbeams     r_V[i] += r_U[i];
32*7d8d0e25Snbeams }
33*7d8d0e25Snbeams 
34*7d8d0e25Snbeams //------------------------------------------------------------------------------
35*7d8d0e25Snbeams // 1D
36*7d8d0e25Snbeams //------------------------------------------------------------------------------
37*7d8d0e25Snbeams 
38*7d8d0e25Snbeams //------------------------------------------------------------------------------
39*7d8d0e25Snbeams // Read DoFs
40*7d8d0e25Snbeams //------------------------------------------------------------------------------
41*7d8d0e25Snbeams inline __device__ void readDofs1d(const int elem, const int tidx,
42*7d8d0e25Snbeams                                   const int tidy, const int tidz,const int comp,
43*7d8d0e25Snbeams                                   const int nelem, const CeedScalar *d_U,
44*7d8d0e25Snbeams                                   CeedScalar *slice) {
45*7d8d0e25Snbeams   for (int i = 0; i < P1D; i++)
46*7d8d0e25Snbeams     slice[i + tidz*T1D] = d_U[i + elem*P1D + comp*P1D*nelem];
47*7d8d0e25Snbeams   for (int i = P1D; i < Q1D; i++)
48*7d8d0e25Snbeams     slice[i + tidz*T1D] = 0.0;
49*7d8d0e25Snbeams }
50*7d8d0e25Snbeams 
51*7d8d0e25Snbeams //------------------------------------------------------------------------------
52*7d8d0e25Snbeams // Write DoFs
53*7d8d0e25Snbeams //------------------------------------------------------------------------------
54*7d8d0e25Snbeams inline __device__ void writeDofs1d(const int elem, const int tidx,
55*7d8d0e25Snbeams                                    const int tidy, const int comp,
56*7d8d0e25Snbeams                                    const int nelem, const CeedScalar &r_V,
57*7d8d0e25Snbeams                                    CeedScalar *d_V) {
58*7d8d0e25Snbeams   if (tidx<P1D)
59*7d8d0e25Snbeams     d_V[tidx + elem*P1D + comp*P1D*nelem] = r_V;
60*7d8d0e25Snbeams }
61*7d8d0e25Snbeams 
62*7d8d0e25Snbeams //------------------------------------------------------------------------------
63*7d8d0e25Snbeams // Read quadrature point data
64*7d8d0e25Snbeams //------------------------------------------------------------------------------
65*7d8d0e25Snbeams inline __device__ void readQuads1d(const int elem, const int tidx,
66*7d8d0e25Snbeams                                    const int tidy, const int tidz, const int comp,
67*7d8d0e25Snbeams                                    const int dim, const int nelem,
68*7d8d0e25Snbeams                                    const CeedScalar *d_U, CeedScalar *slice) {
69*7d8d0e25Snbeams   for (int i = 0; i < Q1D; i++)
70*7d8d0e25Snbeams     slice[i + tidz*T1D] = d_U[i + elem*Q1D + comp*Q1D*nelem +
71*7d8d0e25Snbeams                             dim*BASIS_NCOMP*nelem*Q1D];
72*7d8d0e25Snbeams   for (int i = Q1D; i < P1D; i++)
73*7d8d0e25Snbeams     slice[i + tidz*T1D] = 0.0;
74*7d8d0e25Snbeams }
75*7d8d0e25Snbeams 
76*7d8d0e25Snbeams //------------------------------------------------------------------------------
77*7d8d0e25Snbeams // Write quadrature point data
78*7d8d0e25Snbeams //------------------------------------------------------------------------------
79*7d8d0e25Snbeams inline __device__ void writeQuads1d(const int elem, const int tidx,
80*7d8d0e25Snbeams                                     const int tidy, const int comp,
81*7d8d0e25Snbeams                                     const int dim, const int nelem,
82*7d8d0e25Snbeams                                     const CeedScalar &r_V, CeedScalar *d_V) {
83*7d8d0e25Snbeams   if (tidx<Q1D)
84*7d8d0e25Snbeams     d_V[tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D] = r_V;
85*7d8d0e25Snbeams }
86*7d8d0e25Snbeams 
87*7d8d0e25Snbeams //------------------------------------------------------------------------------
88*7d8d0e25Snbeams // 1D tensor contraction
89*7d8d0e25Snbeams //------------------------------------------------------------------------------
90*7d8d0e25Snbeams inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
91*7d8d0e25Snbeams                                    const int tidy, const int tidz,
92*7d8d0e25Snbeams                                    const CeedScalar &U, const CeedScalar *B,
93*7d8d0e25Snbeams                                    CeedScalar &V) {
94*7d8d0e25Snbeams   V = 0.0;
95*7d8d0e25Snbeams   for (int i = 0; i < P1D; ++i)
96*7d8d0e25Snbeams     V += B[i + tidx*P1D] * slice[i + tidz*T1D]; // Contract x direction
97*7d8d0e25Snbeams }
98*7d8d0e25Snbeams 
99*7d8d0e25Snbeams //------------------------------------------------------------------------------
100*7d8d0e25Snbeams // 1D transpose tensor contraction
101*7d8d0e25Snbeams //------------------------------------------------------------------------------
102*7d8d0e25Snbeams inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
103*7d8d0e25Snbeams     const int tidy, const int tidz,
104*7d8d0e25Snbeams     const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
105*7d8d0e25Snbeams   V = 0.0;
106*7d8d0e25Snbeams   for (int i = 0; i < Q1D; ++i)
107*7d8d0e25Snbeams     V += B[tidx + i*P1D] * slice[i + tidz*T1D]; // Contract x direction
108*7d8d0e25Snbeams }
109*7d8d0e25Snbeams 
110*7d8d0e25Snbeams //------------------------------------------------------------------------------
111*7d8d0e25Snbeams // 1D interpolate to quadrature points
112*7d8d0e25Snbeams //------------------------------------------------------------------------------
113*7d8d0e25Snbeams inline __device__ void interp1d(const CeedInt nelem, const int transpose,
114*7d8d0e25Snbeams                                 const CeedScalar *c_B,
115*7d8d0e25Snbeams                                 const CeedScalar *__restrict__ d_U,
116*7d8d0e25Snbeams                                 CeedScalar *__restrict__ d_V,
117*7d8d0e25Snbeams                                 CeedScalar *slice) {
118*7d8d0e25Snbeams   CeedScalar r_V;
119*7d8d0e25Snbeams   CeedScalar r_t;
120*7d8d0e25Snbeams 
121*7d8d0e25Snbeams   const int tidx = threadIdx.x;
122*7d8d0e25Snbeams   const int tidy = threadIdx.y;
123*7d8d0e25Snbeams   const int tidz = threadIdx.z;
124*7d8d0e25Snbeams 
125*7d8d0e25Snbeams 
126*7d8d0e25Snbeams   for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
127*7d8d0e25Snbeams        elem += gridDim.x*blockDim.z) {
128*7d8d0e25Snbeams     for (int comp = 0; comp < BASIS_NCOMP; comp++) {
129*7d8d0e25Snbeams       if (!transpose) {
130*7d8d0e25Snbeams         readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
131*7d8d0e25Snbeams         ContractX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
132*7d8d0e25Snbeams         writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
133*7d8d0e25Snbeams       } else {
134*7d8d0e25Snbeams         readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
135*7d8d0e25Snbeams         ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
136*7d8d0e25Snbeams         writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
137*7d8d0e25Snbeams       }
138*7d8d0e25Snbeams     }
139*7d8d0e25Snbeams   }
140*7d8d0e25Snbeams }
141*7d8d0e25Snbeams 
142*7d8d0e25Snbeams //------------------------------------------------------------------------------
143*7d8d0e25Snbeams // 1D derivatives at quadrature points
144*7d8d0e25Snbeams //------------------------------------------------------------------------------
145*7d8d0e25Snbeams inline __device__ void grad1d(const CeedInt nelem, const int transpose,
146*7d8d0e25Snbeams                               const CeedScalar *c_B, const CeedScalar *c_G,
147*7d8d0e25Snbeams                               const CeedScalar *__restrict__ d_U,
148*7d8d0e25Snbeams                               CeedScalar *__restrict__ d_V,
149*7d8d0e25Snbeams                               CeedScalar *slice) {
150*7d8d0e25Snbeams   CeedScalar r_U;
151*7d8d0e25Snbeams   CeedScalar r_V;
152*7d8d0e25Snbeams 
153*7d8d0e25Snbeams   const int tidx = threadIdx.x;
154*7d8d0e25Snbeams   const int tidy = threadIdx.y;
155*7d8d0e25Snbeams   const int tidz = threadIdx.z;
156*7d8d0e25Snbeams   int dim;
157*7d8d0e25Snbeams 
158*7d8d0e25Snbeams   for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
159*7d8d0e25Snbeams        elem += gridDim.x*blockDim.z) {
160*7d8d0e25Snbeams     for(int comp = 0; comp < BASIS_NCOMP; comp++) {
161*7d8d0e25Snbeams       if (!transpose) {
162*7d8d0e25Snbeams         readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
163*7d8d0e25Snbeams         ContractX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
164*7d8d0e25Snbeams         dim = 0;
165*7d8d0e25Snbeams         writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
166*7d8d0e25Snbeams       } else {
167*7d8d0e25Snbeams         dim = 0;
168*7d8d0e25Snbeams         readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice);
169*7d8d0e25Snbeams         ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
170*7d8d0e25Snbeams         writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
171*7d8d0e25Snbeams       }
172*7d8d0e25Snbeams     }
173*7d8d0e25Snbeams   }
174*7d8d0e25Snbeams }
175*7d8d0e25Snbeams 
176*7d8d0e25Snbeams //------------------------------------------------------------------------------
177*7d8d0e25Snbeams // 1D Quadrature weights
178*7d8d0e25Snbeams //------------------------------------------------------------------------------
179*7d8d0e25Snbeams __device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
180*7d8d0e25Snbeams                          CeedScalar *w) {
181*7d8d0e25Snbeams   const int tid = threadIdx.x;
182*7d8d0e25Snbeams   const CeedScalar weight = qweight1d[tid];
183*7d8d0e25Snbeams   for (CeedInt elem = blockIdx.x*blockDim.y + threadIdx.y; elem < nelem;
184*7d8d0e25Snbeams        elem += gridDim.x*blockDim.y) {
185*7d8d0e25Snbeams     const int ind = elem*Q1D + tid;
186*7d8d0e25Snbeams     w[ind] = weight;
187*7d8d0e25Snbeams   }
188*7d8d0e25Snbeams }
189*7d8d0e25Snbeams 
190*7d8d0e25Snbeams //------------------------------------------------------------------------------
191*7d8d0e25Snbeams // 2D
192*7d8d0e25Snbeams //------------------------------------------------------------------------------
193*7d8d0e25Snbeams 
194*7d8d0e25Snbeams //------------------------------------------------------------------------------
195*7d8d0e25Snbeams // Read DoFs
196*7d8d0e25Snbeams //------------------------------------------------------------------------------
197*7d8d0e25Snbeams inline __device__ void readDofs2d(const int elem, const int tidx,
198*7d8d0e25Snbeams                                   const int tidy, const int comp,
199*7d8d0e25Snbeams                                   const int nelem, const CeedScalar *d_U,
200*7d8d0e25Snbeams                                   CeedScalar &U) {
201*7d8d0e25Snbeams   U = (tidx<P1D && tidy<P1D) ?
202*7d8d0e25Snbeams       d_U[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] : 0.0;
203*7d8d0e25Snbeams }
204*7d8d0e25Snbeams 
205*7d8d0e25Snbeams //------------------------------------------------------------------------------
206*7d8d0e25Snbeams // Write DoFs
207*7d8d0e25Snbeams //------------------------------------------------------------------------------
208*7d8d0e25Snbeams inline __device__ void writeDofs2d(const int elem, const int tidx,
209*7d8d0e25Snbeams                                    const int tidy, const int comp,
210*7d8d0e25Snbeams                                    const int nelem, const CeedScalar &r_V,
211*7d8d0e25Snbeams                                    CeedScalar *d_V) {
212*7d8d0e25Snbeams   if (tidx<P1D && tidy<P1D)
213*7d8d0e25Snbeams     d_V[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] = r_V;
214*7d8d0e25Snbeams }
215*7d8d0e25Snbeams 
216*7d8d0e25Snbeams //------------------------------------------------------------------------------
217*7d8d0e25Snbeams // Read quadrature point data
218*7d8d0e25Snbeams //------------------------------------------------------------------------------
219*7d8d0e25Snbeams inline __device__ void readQuads2d(const int elem, const int tidx,
220*7d8d0e25Snbeams                                    const int tidy, const int comp,
221*7d8d0e25Snbeams                                    const int dim, const int nelem,
222*7d8d0e25Snbeams                                    const CeedScalar *d_U, CeedScalar &U ) {
223*7d8d0e25Snbeams   U = (tidx<Q1D && tidy<Q1D) ?
224*7d8d0e25Snbeams       d_U[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
225*7d8d0e25Snbeams       dim*BASIS_NCOMP*nelem*Q1D*Q1D] : 0.0;
226*7d8d0e25Snbeams }
227*7d8d0e25Snbeams 
228*7d8d0e25Snbeams //------------------------------------------------------------------------------
229*7d8d0e25Snbeams // Write quadrature point data
230*7d8d0e25Snbeams //------------------------------------------------------------------------------
231*7d8d0e25Snbeams inline __device__ void writeQuads2d(const int elem, const int tidx,
232*7d8d0e25Snbeams                                     const int tidy, const int comp,
233*7d8d0e25Snbeams                                     const int dim, const int nelem,
234*7d8d0e25Snbeams                                     const CeedScalar &r_V, CeedScalar *d_V) {
235*7d8d0e25Snbeams   if (tidx<Q1D && tidy<Q1D)
236*7d8d0e25Snbeams     d_V[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
237*7d8d0e25Snbeams     dim*BASIS_NCOMP*nelem*Q1D*Q1D] = r_V;
238*7d8d0e25Snbeams }
239*7d8d0e25Snbeams 
240*7d8d0e25Snbeams //------------------------------------------------------------------------------
241*7d8d0e25Snbeams // 2D tensor contraction x
242*7d8d0e25Snbeams //------------------------------------------------------------------------------
243*7d8d0e25Snbeams inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
244*7d8d0e25Snbeams                                    const int tidy, const int tidz,
245*7d8d0e25Snbeams                                    const CeedScalar &U, const CeedScalar *B,
246*7d8d0e25Snbeams                                    CeedScalar &V) {
247*7d8d0e25Snbeams   slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
248*7d8d0e25Snbeams   __syncthreads();
249*7d8d0e25Snbeams   V = 0.0;
250*7d8d0e25Snbeams   if (tidx < Q1D)
251*7d8d0e25Snbeams     for (int i = 0; i < P1D; ++i)
252*7d8d0e25Snbeams       V += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
253*7d8d0e25Snbeams   __syncthreads();
254*7d8d0e25Snbeams }
255*7d8d0e25Snbeams 
256*7d8d0e25Snbeams //------------------------------------------------------------------------------
257*7d8d0e25Snbeams // 2D tensor contraction y
258*7d8d0e25Snbeams //------------------------------------------------------------------------------
259*7d8d0e25Snbeams inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
260*7d8d0e25Snbeams                                    const int tidy, const int tidz,
261*7d8d0e25Snbeams                                    const CeedScalar &U, const CeedScalar *B,
262*7d8d0e25Snbeams                                    CeedScalar &V) {
263*7d8d0e25Snbeams   slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
264*7d8d0e25Snbeams   __syncthreads();
265*7d8d0e25Snbeams   V = 0.0;
266*7d8d0e25Snbeams   if (tidy < Q1D)
267*7d8d0e25Snbeams     for (int i = 0; i < P1D; ++i)
268*7d8d0e25Snbeams       V += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
269*7d8d0e25Snbeams   __syncthreads();
270*7d8d0e25Snbeams }
271*7d8d0e25Snbeams 
272*7d8d0e25Snbeams //------------------------------------------------------------------------------
273*7d8d0e25Snbeams // 2D transpose tensor contraction y
274*7d8d0e25Snbeams //------------------------------------------------------------------------------
275*7d8d0e25Snbeams inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
276*7d8d0e25Snbeams     const int tidy, const int tidz,
277*7d8d0e25Snbeams     const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
278*7d8d0e25Snbeams   slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
279*7d8d0e25Snbeams   __syncthreads();
280*7d8d0e25Snbeams   V = 0.0;
281*7d8d0e25Snbeams   if (tidy < P1D)
282*7d8d0e25Snbeams     for (int i = 0; i < Q1D; ++i)
283*7d8d0e25Snbeams       V += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
284*7d8d0e25Snbeams   __syncthreads();
285*7d8d0e25Snbeams }
286*7d8d0e25Snbeams 
287*7d8d0e25Snbeams //------------------------------------------------------------------------------
288*7d8d0e25Snbeams // 2D transpose tensor contraction x
289*7d8d0e25Snbeams //------------------------------------------------------------------------------
290*7d8d0e25Snbeams inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
291*7d8d0e25Snbeams     const int tidy, const int tidz,
292*7d8d0e25Snbeams     const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
293*7d8d0e25Snbeams   slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
294*7d8d0e25Snbeams   __syncthreads();
295*7d8d0e25Snbeams   V = 0.0;
296*7d8d0e25Snbeams   if (tidx < P1D)
297*7d8d0e25Snbeams     for (int i = 0; i < Q1D; ++i)
298*7d8d0e25Snbeams       V += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
299*7d8d0e25Snbeams   __syncthreads();
300*7d8d0e25Snbeams }
301*7d8d0e25Snbeams 
302*7d8d0e25Snbeams //------------------------------------------------------------------------------
303*7d8d0e25Snbeams // 2D interpolate to quadrature points
304*7d8d0e25Snbeams //------------------------------------------------------------------------------
305*7d8d0e25Snbeams inline __device__ void interp2d(const CeedInt nelem, const int transpose,
306*7d8d0e25Snbeams                                 const CeedScalar *c_B,
307*7d8d0e25Snbeams                                 const CeedScalar *__restrict__ d_U,
308*7d8d0e25Snbeams                                 CeedScalar *__restrict__ d_V,
309*7d8d0e25Snbeams                                 CeedScalar *slice) {
310*7d8d0e25Snbeams   CeedScalar r_V;
311*7d8d0e25Snbeams   CeedScalar r_t;
312*7d8d0e25Snbeams 
313*7d8d0e25Snbeams   const int tidx = threadIdx.x;
314*7d8d0e25Snbeams   const int tidy = threadIdx.y;
315*7d8d0e25Snbeams   const int tidz = threadIdx.z;
316*7d8d0e25Snbeams   const int blockElem = tidz/BASIS_NCOMP;
317*7d8d0e25Snbeams   const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
318*7d8d0e25Snbeams   const int comp = tidz%BASIS_NCOMP;
319*7d8d0e25Snbeams 
320*7d8d0e25Snbeams   for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
321*7d8d0e25Snbeams        elem += gridDim.x*elemsPerBlock) {
322*7d8d0e25Snbeams     const int comp = tidz%BASIS_NCOMP;
323*7d8d0e25Snbeams     r_V = 0.0;
324*7d8d0e25Snbeams     r_t = 0.0;
325*7d8d0e25Snbeams     if (!transpose) {
326*7d8d0e25Snbeams       readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
327*7d8d0e25Snbeams       ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
328*7d8d0e25Snbeams       ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
329*7d8d0e25Snbeams       writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
330*7d8d0e25Snbeams     } else {
331*7d8d0e25Snbeams       readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
332*7d8d0e25Snbeams       ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
333*7d8d0e25Snbeams       ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
334*7d8d0e25Snbeams       writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
335*7d8d0e25Snbeams     }
336*7d8d0e25Snbeams   }
337*7d8d0e25Snbeams }
338*7d8d0e25Snbeams 
339*7d8d0e25Snbeams //------------------------------------------------------------------------------
340*7d8d0e25Snbeams // 2D derivatives at quadrature points
341*7d8d0e25Snbeams //------------------------------------------------------------------------------
342*7d8d0e25Snbeams inline __device__ void grad2d(const CeedInt nelem, const int transpose,
343*7d8d0e25Snbeams                               const CeedScalar *c_B, const CeedScalar *c_G,
344*7d8d0e25Snbeams                               const CeedScalar *__restrict__ d_U,
345*7d8d0e25Snbeams                               CeedScalar *__restrict__ d_V, CeedScalar *slice) {
346*7d8d0e25Snbeams   CeedScalar r_U;
347*7d8d0e25Snbeams   CeedScalar r_V;
348*7d8d0e25Snbeams   CeedScalar r_t;
349*7d8d0e25Snbeams 
350*7d8d0e25Snbeams   const int tidx = threadIdx.x;
351*7d8d0e25Snbeams   const int tidy = threadIdx.y;
352*7d8d0e25Snbeams   const int tidz = threadIdx.z;
353*7d8d0e25Snbeams   const int blockElem = tidz/BASIS_NCOMP;
354*7d8d0e25Snbeams   const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
355*7d8d0e25Snbeams   const int comp = tidz%BASIS_NCOMP;
356*7d8d0e25Snbeams   int dim;
357*7d8d0e25Snbeams 
358*7d8d0e25Snbeams   for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
359*7d8d0e25Snbeams        elem += gridDim.x*elemsPerBlock) {
360*7d8d0e25Snbeams     if (!transpose) {
361*7d8d0e25Snbeams       readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
362*7d8d0e25Snbeams       ContractX2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
363*7d8d0e25Snbeams       ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
364*7d8d0e25Snbeams       dim = 0;
365*7d8d0e25Snbeams       writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
366*7d8d0e25Snbeams       ContractX2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
367*7d8d0e25Snbeams       ContractY2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
368*7d8d0e25Snbeams       dim = 1;
369*7d8d0e25Snbeams       writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
370*7d8d0e25Snbeams     } else {
371*7d8d0e25Snbeams       dim = 0;
372*7d8d0e25Snbeams       readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
373*7d8d0e25Snbeams       ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
374*7d8d0e25Snbeams       ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
375*7d8d0e25Snbeams       dim = 1;
376*7d8d0e25Snbeams       readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
377*7d8d0e25Snbeams       ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
378*7d8d0e25Snbeams       ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
379*7d8d0e25Snbeams       r_V += r_U;
380*7d8d0e25Snbeams       writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
381*7d8d0e25Snbeams     }
382*7d8d0e25Snbeams   }
383*7d8d0e25Snbeams }
384*7d8d0e25Snbeams 
385*7d8d0e25Snbeams //------------------------------------------------------------------------------
386*7d8d0e25Snbeams // 2D quadrature weights
387*7d8d0e25Snbeams //------------------------------------------------------------------------------
388*7d8d0e25Snbeams __device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
389*7d8d0e25Snbeams                          CeedScalar *w) {
390*7d8d0e25Snbeams   const int i = threadIdx.x;
391*7d8d0e25Snbeams   const int j = threadIdx.y;
392*7d8d0e25Snbeams   const CeedScalar weight = qweight1d[i]*qweight1d[j];
393*7d8d0e25Snbeams   for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
394*7d8d0e25Snbeams        elem += gridDim.x*blockDim.z) {
395*7d8d0e25Snbeams     const int ind = elem*Q1D*Q1D + i + j*Q1D;
396*7d8d0e25Snbeams     w[ind] = weight;
397*7d8d0e25Snbeams   }
398*7d8d0e25Snbeams }
399*7d8d0e25Snbeams 
400*7d8d0e25Snbeams //------------------------------------------------------------------------------
401*7d8d0e25Snbeams // 3D
402*7d8d0e25Snbeams //------------------------------------------------------------------------------
403*7d8d0e25Snbeams 
404*7d8d0e25Snbeams //------------------------------------------------------------------------------
405*7d8d0e25Snbeams // Read DoFs
406*7d8d0e25Snbeams //------------------------------------------------------------------------------
407*7d8d0e25Snbeams inline __device__ void readDofs3d(const int elem, const int tidx,
408*7d8d0e25Snbeams                                   const int tidy, const int comp,
409*7d8d0e25Snbeams                                   const int nelem, const CeedScalar *d_U,
410*7d8d0e25Snbeams                                   CeedScalar *r_U) {
411*7d8d0e25Snbeams   for (int i = 0; i < P1D; i++)
412*7d8d0e25Snbeams     r_U[i] = (tidx < P1D && tidy < P1D) ?
413*7d8d0e25Snbeams               d_U[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D +
414*7d8d0e25Snbeams                   comp*P1D*P1D*P1D*nelem] : 0.0;
415*7d8d0e25Snbeams   for (int i = P1D; i < Q1D; i++)
416*7d8d0e25Snbeams     r_U[i] = 0.0;
417*7d8d0e25Snbeams }
418*7d8d0e25Snbeams 
419*7d8d0e25Snbeams //------------------------------------------------------------------------------
420*7d8d0e25Snbeams // Write DoFs
421*7d8d0e25Snbeams //------------------------------------------------------------------------------
422*7d8d0e25Snbeams inline __device__ void writeDofs3d(const int elem, const int tidx,
423*7d8d0e25Snbeams                                    const int tidy, const int comp,
424*7d8d0e25Snbeams                                    const int nelem, const CeedScalar *r_V,
425*7d8d0e25Snbeams                                    CeedScalar *d_V) {
426*7d8d0e25Snbeams   if (tidx < P1D && tidy < P1D) {
427*7d8d0e25Snbeams     for (int i = 0; i < P1D; i++)
428*7d8d0e25Snbeams       d_V[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D +
429*7d8d0e25Snbeams           comp*P1D*P1D*P1D*nelem] = r_V[i];
430*7d8d0e25Snbeams   }
431*7d8d0e25Snbeams }
432*7d8d0e25Snbeams 
433*7d8d0e25Snbeams //------------------------------------------------------------------------------
434*7d8d0e25Snbeams // Read quadrature point data
435*7d8d0e25Snbeams //------------------------------------------------------------------------------
436*7d8d0e25Snbeams inline __device__ void readQuads3d(const int elem, const int tidx,
437*7d8d0e25Snbeams                                    const int tidy, const int comp,
438*7d8d0e25Snbeams                                    const int dim, const int nelem,
439*7d8d0e25Snbeams                                    const CeedScalar *d_U, CeedScalar *r_U) {
440*7d8d0e25Snbeams   for (int i = 0; i < Q1D; i++)
441*7d8d0e25Snbeams     r_U[i] = (tidx < Q1D && tidy < Q1D) ?
442*7d8d0e25Snbeams               d_U[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D +
443*7d8d0e25Snbeams               comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] : 0.0;
444*7d8d0e25Snbeams   for (int i = Q1D; i < P1D; i++)
445*7d8d0e25Snbeams     r_U[i] = 0.0;
446*7d8d0e25Snbeams }
447*7d8d0e25Snbeams 
448*7d8d0e25Snbeams //------------------------------------------------------------------------------
449*7d8d0e25Snbeams // Write quadrature point data
450*7d8d0e25Snbeams //------------------------------------------------------------------------------
451*7d8d0e25Snbeams inline __device__ void writeQuads3d(const int elem, const int tidx,
452*7d8d0e25Snbeams                                     const int tidy, const int comp,
453*7d8d0e25Snbeams                                     const int dim, const int nelem,
454*7d8d0e25Snbeams                                     const CeedScalar *r_V, CeedScalar *d_V) {
455*7d8d0e25Snbeams   if (tidx < Q1D && tidy < Q1D) {
456*7d8d0e25Snbeams     for (int i = 0; i < Q1D; i++)
457*7d8d0e25Snbeams       d_V[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
458*7d8d0e25Snbeams           dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] = r_V[i];
459*7d8d0e25Snbeams   }
460*7d8d0e25Snbeams }
461*7d8d0e25Snbeams 
462*7d8d0e25Snbeams //------------------------------------------------------------------------------
463*7d8d0e25Snbeams // 3D tensor contract x
464*7d8d0e25Snbeams //------------------------------------------------------------------------------
465*7d8d0e25Snbeams inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
466*7d8d0e25Snbeams                                    const int tidy, const int tidz,
467*7d8d0e25Snbeams                                    const CeedScalar *U,
468*7d8d0e25Snbeams                                    const CeedScalar *B,
469*7d8d0e25Snbeams                                    CeedScalar *V) {
470*7d8d0e25Snbeams   for (int k = 0; k < P1D; ++k) {
471*7d8d0e25Snbeams     slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
472*7d8d0e25Snbeams     __syncthreads();
473*7d8d0e25Snbeams     V[k] = 0.0;
474*7d8d0e25Snbeams     if (tidx < Q1D && tidy < P1D)
475*7d8d0e25Snbeams       for (int i = 0; i < P1D; ++i)
476*7d8d0e25Snbeams         V[k] += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
477*7d8d0e25Snbeams     __syncthreads();
478*7d8d0e25Snbeams   }
479*7d8d0e25Snbeams }
480*7d8d0e25Snbeams 
481*7d8d0e25Snbeams //------------------------------------------------------------------------------
482*7d8d0e25Snbeams // 3D tensor contract y
483*7d8d0e25Snbeams //------------------------------------------------------------------------------
484*7d8d0e25Snbeams inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
485*7d8d0e25Snbeams                                    const int tidy, const int tidz,
486*7d8d0e25Snbeams                                    const CeedScalar *U,
487*7d8d0e25Snbeams                                    const CeedScalar *B,
488*7d8d0e25Snbeams                                    CeedScalar *V) {
489*7d8d0e25Snbeams   for (int k = 0; k < P1D; ++k) {
490*7d8d0e25Snbeams     slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
491*7d8d0e25Snbeams     __syncthreads();
492*7d8d0e25Snbeams     V[k] = 0.0;
493*7d8d0e25Snbeams     if (tidx < Q1D && tidy < Q1D)
494*7d8d0e25Snbeams       for (int i = 0; i < P1D; ++i)
495*7d8d0e25Snbeams         V[k] += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
496*7d8d0e25Snbeams     __syncthreads();
497*7d8d0e25Snbeams   }
498*7d8d0e25Snbeams }
499*7d8d0e25Snbeams 
500*7d8d0e25Snbeams //------------------------------------------------------------------------------
501*7d8d0e25Snbeams // 3D tensor contract z
502*7d8d0e25Snbeams //------------------------------------------------------------------------------
503*7d8d0e25Snbeams inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
504*7d8d0e25Snbeams                                    const int tidy, const int tidz,
505*7d8d0e25Snbeams                                    const CeedScalar *U,
506*7d8d0e25Snbeams                                    const CeedScalar *B,
507*7d8d0e25Snbeams                                    CeedScalar *V) {
508*7d8d0e25Snbeams   for (int k = 0; k < Q1D; ++k) {
509*7d8d0e25Snbeams     V[k] = 0.0;
510*7d8d0e25Snbeams     if (tidx < Q1D && tidy < Q1D)
511*7d8d0e25Snbeams       for (int i = 0; i < P1D; ++i)
512*7d8d0e25Snbeams         V[k] += B[i + k*P1D] * U[i]; // Contract z direction
513*7d8d0e25Snbeams   }
514*7d8d0e25Snbeams   for (int k = Q1D; k < P1D; ++k)
515*7d8d0e25Snbeams     V[k] = 0.0;
516*7d8d0e25Snbeams }
517*7d8d0e25Snbeams 
518*7d8d0e25Snbeams //------------------------------------------------------------------------------
519*7d8d0e25Snbeams // 3D transpose tensor contract z
520*7d8d0e25Snbeams //------------------------------------------------------------------------------
521*7d8d0e25Snbeams inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
522*7d8d0e25Snbeams                                             const int tidy, const int tidz,
523*7d8d0e25Snbeams                                             const CeedScalar *U,
524*7d8d0e25Snbeams                                             const CeedScalar *B,
525*7d8d0e25Snbeams                                             CeedScalar *V) {
526*7d8d0e25Snbeams   for (int k = 0; k < P1D; ++k) {
527*7d8d0e25Snbeams     V[k] = 0.0;
528*7d8d0e25Snbeams     if (tidx < Q1D && tidy < Q1D)
529*7d8d0e25Snbeams       for (int i = 0; i < Q1D; ++i)
530*7d8d0e25Snbeams         V[k] += B[k + i*P1D] * U[i]; // Contract z direction
531*7d8d0e25Snbeams   }
532*7d8d0e25Snbeams   for (int k = P1D; k < Q1D; ++k)
533*7d8d0e25Snbeams     V[k] = 0.0;
534*7d8d0e25Snbeams }
535*7d8d0e25Snbeams 
536*7d8d0e25Snbeams //------------------------------------------------------------------------------
537*7d8d0e25Snbeams // 3D transpose tensor contract y
538*7d8d0e25Snbeams //------------------------------------------------------------------------------
539*7d8d0e25Snbeams inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
540*7d8d0e25Snbeams                                             const int tidy, const int tidz,
541*7d8d0e25Snbeams                                             const CeedScalar *U,
542*7d8d0e25Snbeams                                             const CeedScalar *B,
543*7d8d0e25Snbeams                                             CeedScalar *V) {
544*7d8d0e25Snbeams   for (int k = 0; k < P1D; ++k) {
545*7d8d0e25Snbeams     slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
546*7d8d0e25Snbeams     __syncthreads();
547*7d8d0e25Snbeams     V[k] = 0.0;
548*7d8d0e25Snbeams     if (tidx < Q1D && tidy < P1D)
549*7d8d0e25Snbeams       for (int i = 0; i < Q1D; ++i)
550*7d8d0e25Snbeams         V[k] += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
551*7d8d0e25Snbeams     __syncthreads();
552*7d8d0e25Snbeams   }
553*7d8d0e25Snbeams }
554*7d8d0e25Snbeams 
555*7d8d0e25Snbeams //------------------------------------------------------------------------------
556*7d8d0e25Snbeams // 3D transpose tensor contract x
557*7d8d0e25Snbeams //------------------------------------------------------------------------------
558*7d8d0e25Snbeams inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
559*7d8d0e25Snbeams                                             const int tidy, const int tidz,
560*7d8d0e25Snbeams                                             const CeedScalar *U,
561*7d8d0e25Snbeams                                             const CeedScalar *B,
562*7d8d0e25Snbeams                                             CeedScalar *V) {
563*7d8d0e25Snbeams   for (int k = 0; k < P1D; ++k) {
564*7d8d0e25Snbeams     slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
565*7d8d0e25Snbeams     __syncthreads();
566*7d8d0e25Snbeams     V[k] = 0.0;
567*7d8d0e25Snbeams     if (tidx < P1D && tidy < P1D)
568*7d8d0e25Snbeams       for (int i = 0; i < Q1D; ++i)
569*7d8d0e25Snbeams         V[k] += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
570*7d8d0e25Snbeams     __syncthreads();
571*7d8d0e25Snbeams   }
572*7d8d0e25Snbeams }
573*7d8d0e25Snbeams 
574*7d8d0e25Snbeams //------------------------------------------------------------------------------
575*7d8d0e25Snbeams // 3D interpolate to quadrature points
576*7d8d0e25Snbeams //------------------------------------------------------------------------------
577*7d8d0e25Snbeams inline __device__ void interp3d(const CeedInt nelem, const int transpose,
578*7d8d0e25Snbeams                                 const CeedScalar *c_B,
579*7d8d0e25Snbeams                                 const CeedScalar *__restrict__ d_U,
580*7d8d0e25Snbeams                                 CeedScalar *__restrict__ d_V,
581*7d8d0e25Snbeams                                 CeedScalar *slice) {
582*7d8d0e25Snbeams   CeedScalar r_V[T1D];
583*7d8d0e25Snbeams   CeedScalar r_t[T1D];
584*7d8d0e25Snbeams 
585*7d8d0e25Snbeams   const int tidx = threadIdx.x;
586*7d8d0e25Snbeams   const int tidy = threadIdx.y;
587*7d8d0e25Snbeams   const int tidz = threadIdx.z;
588*7d8d0e25Snbeams   const int blockElem = tidz/BASIS_NCOMP;
589*7d8d0e25Snbeams   const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
590*7d8d0e25Snbeams   const int comp = tidz%BASIS_NCOMP;
591*7d8d0e25Snbeams 
592*7d8d0e25Snbeams   for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
593*7d8d0e25Snbeams        elem += gridDim.x*elemsPerBlock) {
594*7d8d0e25Snbeams     for (int i = 0; i < T1D; ++i) {
595*7d8d0e25Snbeams       r_V[i] = 0.0;
596*7d8d0e25Snbeams       r_t[i] = 0.0;
597*7d8d0e25Snbeams     }
598*7d8d0e25Snbeams     if (!transpose) {
599*7d8d0e25Snbeams       readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
600*7d8d0e25Snbeams       ContractX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
601*7d8d0e25Snbeams       ContractY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
602*7d8d0e25Snbeams       ContractZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
603*7d8d0e25Snbeams       writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
604*7d8d0e25Snbeams     } else {
605*7d8d0e25Snbeams       readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
606*7d8d0e25Snbeams       ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
607*7d8d0e25Snbeams       ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
608*7d8d0e25Snbeams       ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
609*7d8d0e25Snbeams       writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
610*7d8d0e25Snbeams     }
611*7d8d0e25Snbeams   }
612*7d8d0e25Snbeams }
613*7d8d0e25Snbeams 
//------------------------------------------------------------------------------
// 3D derivatives at quadrature points
//------------------------------------------------------------------------------
// Per-thread driver for the 3D gradient. threadIdx.z is split into an element
// slot within the block (tidz / BASIS_NCOMP) and a component
// (tidz % BASIS_NCOMP); each block advances through the elements with a
// stride of gridDim.x*elemsPerBlock.
//   transpose == 0: read the nodal values once into r_U, then for each
//     direction d apply c_G along d and c_B along the other two directions,
//     and write the result as gradient component d (dim) at the quadrature
//     points.
//   transpose != 0: for each direction d read gradient component d, apply the
//     transposed contractions (c_G along d, c_B elsewhere), accumulate the
//     three results in r_V, and write r_V back with writeDofs3d.
// slice is the scratch buffer used by the x/y contractions.
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // NOTE(review): an earlier note here said to use P1D for one of these
  // buffers; all three are currently sized T1D, the max of P1D and Q1D
  // passed by the host compile call. TODO confirm whether one can shrink.
  CeedScalar r_U[T1D];
  CeedScalar r_V[T1D];
  CeedScalar r_t[T1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  // Gradient component index passed to readQuads3d / writeQuads3d
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    for (int i = 0; i < T1D; ++i) {
      r_U[i] = 0.0;
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // r_U is passed as the const input of ContractX3d only, so the nodal
      // values are read once and reused for all three directions.
      // Component 0: c_G along x, c_B along y and z
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // Component 1: c_G along y
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_G, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 1;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // Component 2: c_G along z
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 2;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      // Component 0: transposed c_G along x; the result initializes r_V.
      // r_U is reused as scratch and reloaded by readQuads3d each time.
      dim = 0;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      // Component 1: transposed c_G along y; add() accumulates r_t into r_V
      dim = 1;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_G, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      // Component 2: transposed c_G along z, accumulated likewise
      dim = 2;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
681*7d8d0e25Snbeams 
682*7d8d0e25Snbeams //------------------------------------------------------------------------------
683*7d8d0e25Snbeams // 3D quadrature weights
684*7d8d0e25Snbeams //------------------------------------------------------------------------------
685*7d8d0e25Snbeams __device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
686*7d8d0e25Snbeams                          CeedScalar *w) {
687*7d8d0e25Snbeams   const int i = threadIdx.x;
688*7d8d0e25Snbeams   const int j = threadIdx.y;
689*7d8d0e25Snbeams   const int k = threadIdx.z;
690*7d8d0e25Snbeams   const CeedScalar weight = qweight1d[i]*qweight1d[j]*qweight1d[k];
691*7d8d0e25Snbeams   for (int e = blockIdx.x; e < nelem; e += gridDim.x) {
692*7d8d0e25Snbeams     const int ind = e*Q1D*Q1D*Q1D + i + j*Q1D + k*Q1D*Q1D;
693*7d8d0e25Snbeams     w[ind] = weight;
694*7d8d0e25Snbeams   }
695*7d8d0e25Snbeams }
696*7d8d0e25Snbeams 
697*7d8d0e25Snbeams 
698*7d8d0e25Snbeams //------------------------------------------------------------------------------
699*7d8d0e25Snbeams // Basis kernels
700*7d8d0e25Snbeams //------------------------------------------------------------------------------
701*7d8d0e25Snbeams 
702*7d8d0e25Snbeams //------------------------------------------------------------------------------
703*7d8d0e25Snbeams // Interp kernel by dim
704*7d8d0e25Snbeams //------------------------------------------------------------------------------
705*7d8d0e25Snbeams extern "C" __global__ void interp(const CeedInt nelem, const int transpose,
706*7d8d0e25Snbeams                                   const CeedScalar *c_B,
707*7d8d0e25Snbeams                                   const CeedScalar *__restrict__ d_U,
708*7d8d0e25Snbeams                                   CeedScalar *__restrict__ d_V) {
709*7d8d0e25Snbeams   HIP_DYNAMIC_SHARED( double, slice)
710*7d8d0e25Snbeams   if (BASIS_DIM == 1) {
711*7d8d0e25Snbeams     interp1d(nelem, transpose, c_B, d_U, d_V, slice);
712*7d8d0e25Snbeams   } else if (BASIS_DIM == 2) {
713*7d8d0e25Snbeams     interp2d(nelem, transpose, c_B, d_U, d_V, slice);
714*7d8d0e25Snbeams   } else if (BASIS_DIM == 3) {
715*7d8d0e25Snbeams     interp3d(nelem, transpose, c_B, d_U, d_V, slice);
716*7d8d0e25Snbeams   }
717*7d8d0e25Snbeams }
718*7d8d0e25Snbeams 
719*7d8d0e25Snbeams //------------------------------------------------------------------------------
720*7d8d0e25Snbeams // Grad kernel by dim
721*7d8d0e25Snbeams //------------------------------------------------------------------------------
722*7d8d0e25Snbeams extern "C" __global__ void grad(const CeedInt nelem, const int transpose,
723*7d8d0e25Snbeams                                 const CeedScalar *c_B, const CeedScalar *c_G,
724*7d8d0e25Snbeams                                 const CeedScalar *__restrict__ d_U,
725*7d8d0e25Snbeams                                 CeedScalar *__restrict__ d_V) {
726*7d8d0e25Snbeams   HIP_DYNAMIC_SHARED( double, slice)
727*7d8d0e25Snbeams   if (BASIS_DIM == 1) {
728*7d8d0e25Snbeams     grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
729*7d8d0e25Snbeams   } else if (BASIS_DIM == 2) {
730*7d8d0e25Snbeams     grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
731*7d8d0e25Snbeams   } else if (BASIS_DIM == 3) {
732*7d8d0e25Snbeams     grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
733*7d8d0e25Snbeams   }
734*7d8d0e25Snbeams }
735*7d8d0e25Snbeams 
736*7d8d0e25Snbeams //------------------------------------------------------------------------------
737*7d8d0e25Snbeams // Weight kernels by dim
738*7d8d0e25Snbeams //------------------------------------------------------------------------------
739*7d8d0e25Snbeams extern "C" __global__ void weight(const CeedInt nelem,
740*7d8d0e25Snbeams                                   const CeedScalar *__restrict__ qweight1d,
741*7d8d0e25Snbeams                                   CeedScalar *__restrict__ v) {
742*7d8d0e25Snbeams   if (BASIS_DIM == 1) {
743*7d8d0e25Snbeams     weight1d(nelem, qweight1d, v);
744*7d8d0e25Snbeams   } else if (BASIS_DIM == 2) {
745*7d8d0e25Snbeams     weight2d(nelem, qweight1d, v);
746*7d8d0e25Snbeams   } else if (BASIS_DIM == 3) {
747*7d8d0e25Snbeams     weight3d(nelem, qweight1d, v);
748*7d8d0e25Snbeams   }
749*7d8d0e25Snbeams }
750*7d8d0e25Snbeams 
751*7d8d0e25Snbeams );
752*7d8d0e25Snbeams // *INDENT-ON*
753*7d8d0e25Snbeams 
//------------------------------------------------------------------------------
// Device initialization
//------------------------------------------------------------------------------
// Helpers defined elsewhere in the HIP backend. CeedBasisApplyTensor_Hip_shared
// calls them with the device copies of the 1D interpolation (and gradient)
// matrices before each launch. The pointers they return through the last
// arguments are passed to the interp/grad kernels as c_B (and c_G).
// NOTE(review): the c_ prefix suggests these point into constant memory;
// confirm against the definitions.
int CeedHipInitInterp(CeedScalar *d_B, CeedInt P1d, CeedInt Q1d,
                      CeedScalar **c_B);
int CeedHipInitInterpGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P1d,
                          CeedInt Q1d, CeedScalar **c_B_ptr,
                          CeedScalar **c_G_ptr);
762*7d8d0e25Snbeams 
763*7d8d0e25Snbeams //------------------------------------------------------------------------------
764*7d8d0e25Snbeams // Apply basis
765*7d8d0e25Snbeams //------------------------------------------------------------------------------
766*7d8d0e25Snbeams int CeedBasisApplyTensor_Hip_shared(CeedBasis basis, const CeedInt nelem,
767*7d8d0e25Snbeams                                     CeedTransposeMode tmode,
768*7d8d0e25Snbeams                                     CeedEvalMode emode, CeedVector u,
769*7d8d0e25Snbeams                                     CeedVector v) {
770*7d8d0e25Snbeams   int ierr;
771*7d8d0e25Snbeams   Ceed ceed;
772*7d8d0e25Snbeams   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
773*7d8d0e25Snbeams   Ceed_Hip_shared *ceed_Hip;
774*7d8d0e25Snbeams   CeedGetData(ceed, &ceed_Hip); CeedChk(ierr);
775*7d8d0e25Snbeams   CeedBasis_Hip_shared *data;
776*7d8d0e25Snbeams   CeedBasisGetData(basis, &data); CeedChk(ierr);
777*7d8d0e25Snbeams   const CeedInt transpose = tmode == CEED_TRANSPOSE;
778*7d8d0e25Snbeams   CeedInt dim, ncomp;
779*7d8d0e25Snbeams   ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
780*7d8d0e25Snbeams   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
781*7d8d0e25Snbeams 
782*7d8d0e25Snbeams   // Read vectors
783*7d8d0e25Snbeams   const CeedScalar *d_u;
784*7d8d0e25Snbeams   CeedScalar *d_v;
785*7d8d0e25Snbeams   if (emode != CEED_EVAL_WEIGHT) {
786*7d8d0e25Snbeams     ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChk(ierr);
787*7d8d0e25Snbeams   }
788*7d8d0e25Snbeams   ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChk(ierr);
789*7d8d0e25Snbeams 
790*7d8d0e25Snbeams   // Clear v for transpose mode
791*7d8d0e25Snbeams   if (tmode == CEED_TRANSPOSE) {
792*7d8d0e25Snbeams     CeedInt length;
793*7d8d0e25Snbeams     ierr = CeedVectorGetLength(v, &length); CeedChk(ierr);
794*7d8d0e25Snbeams     ierr = hipMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChk(ierr);
795*7d8d0e25Snbeams   }
796*7d8d0e25Snbeams 
797*7d8d0e25Snbeams   // Apply basis operation
798*7d8d0e25Snbeams   switch (emode) {
799*7d8d0e25Snbeams   case CEED_EVAL_INTERP: {
800*7d8d0e25Snbeams     CeedInt P1d, Q1d;
801*7d8d0e25Snbeams     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
802*7d8d0e25Snbeams     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
803*7d8d0e25Snbeams     CeedInt thread1d = CeedIntMax(Q1d, P1d);
804*7d8d0e25Snbeams     ierr = CeedHipInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B);
805*7d8d0e25Snbeams     CeedChk(ierr);
806*7d8d0e25Snbeams     void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
807*7d8d0e25Snbeams                           &d_u, &d_v
808*7d8d0e25Snbeams                          };
809*7d8d0e25Snbeams     if (dim == 1) {
810*7d8d0e25Snbeams       CeedInt elemsPerBlock = 32*thread1d > 256? 256/thread1d : 32;
811*7d8d0e25Snbeams       elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1;
812*7d8d0e25Snbeams       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
813*7d8d0e25Snbeams                                              ? 1 : 0 );
814*7d8d0e25Snbeams       CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
815*7d8d0e25Snbeams       ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, 1,
816*7d8d0e25Snbeams                                        elemsPerBlock, sharedMem,
817*7d8d0e25Snbeams                                        interpargs); CeedChk(ierr);
818*7d8d0e25Snbeams     } else if (dim == 2) {
819*7d8d0e25Snbeams       const CeedInt optElems[7] = {0,32,8,6,4,2,6};
820*7d8d0e25Snbeams       // elemsPerBlock must be at least 1
821*7d8d0e25Snbeams       CeedInt elemsPerBlock = CeedIntMax(thread1d<7?optElems[thread1d]/ncomp:1, 1);
822*7d8d0e25Snbeams       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
823*7d8d0e25Snbeams                                              ? 1 : 0 );
824*7d8d0e25Snbeams       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
825*7d8d0e25Snbeams       ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d,
826*7d8d0e25Snbeams                                        ncomp*elemsPerBlock, sharedMem,
827*7d8d0e25Snbeams                                        interpargs); CeedChk(ierr);
828*7d8d0e25Snbeams     } else if (dim == 3) {
829*7d8d0e25Snbeams       CeedInt elemsPerBlock = 1;
830*7d8d0e25Snbeams       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
831*7d8d0e25Snbeams                                              ? 1 : 0 );
832*7d8d0e25Snbeams       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
833*7d8d0e25Snbeams       ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d,
834*7d8d0e25Snbeams                                        ncomp*elemsPerBlock, sharedMem,
835*7d8d0e25Snbeams                                        interpargs); CeedChk(ierr);
836*7d8d0e25Snbeams     }
837*7d8d0e25Snbeams   } break;
838*7d8d0e25Snbeams   case CEED_EVAL_GRAD: {
839*7d8d0e25Snbeams     CeedInt P1d, Q1d;
840*7d8d0e25Snbeams     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
841*7d8d0e25Snbeams     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
842*7d8d0e25Snbeams     CeedInt thread1d = CeedIntMax(Q1d, P1d);
843*7d8d0e25Snbeams     ierr = CeedHipInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d,
844*7d8d0e25Snbeams                                  Q1d, &data->c_B, &data->c_G);
845*7d8d0e25Snbeams     CeedChk(ierr);
846*7d8d0e25Snbeams     void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
847*7d8d0e25Snbeams                         &data->c_G, &d_u, &d_v
848*7d8d0e25Snbeams                        };
849*7d8d0e25Snbeams     if (dim == 1) {
850*7d8d0e25Snbeams       CeedInt elemsPerBlock = 32*thread1d > 256? 256/thread1d : 32;
851*7d8d0e25Snbeams       elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1;
852*7d8d0e25Snbeams       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
853*7d8d0e25Snbeams                                              ? 1 : 0 );
854*7d8d0e25Snbeams       CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
855*7d8d0e25Snbeams       ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, 1,
856*7d8d0e25Snbeams                                        elemsPerBlock, sharedMem, gradargs);
857*7d8d0e25Snbeams       CeedChk(ierr);
858*7d8d0e25Snbeams     } else if (dim == 2) {
859*7d8d0e25Snbeams       const CeedInt optElems[7] = {0,32,8,6,4,2,6};
860*7d8d0e25Snbeams       // elemsPerBlock must be at least 1
861*7d8d0e25Snbeams       CeedInt elemsPerBlock = CeedIntMax(thread1d<7?optElems[thread1d]/ncomp:1, 1);
862*7d8d0e25Snbeams       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
863*7d8d0e25Snbeams                                              ? 1 : 0 );
864*7d8d0e25Snbeams       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
865*7d8d0e25Snbeams       ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d,
866*7d8d0e25Snbeams                                        ncomp*elemsPerBlock, sharedMem,
867*7d8d0e25Snbeams                                        gradargs); CeedChk(ierr);
868*7d8d0e25Snbeams     } else if (dim == 3) {
869*7d8d0e25Snbeams       CeedInt elemsPerBlock = 1;
870*7d8d0e25Snbeams       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
871*7d8d0e25Snbeams                                              ? 1 : 0 );
872*7d8d0e25Snbeams       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
873*7d8d0e25Snbeams       ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d,
874*7d8d0e25Snbeams                                        ncomp*elemsPerBlock, sharedMem,
875*7d8d0e25Snbeams                                        gradargs); CeedChk(ierr);
876*7d8d0e25Snbeams     }
877*7d8d0e25Snbeams   } break;
878*7d8d0e25Snbeams   case CEED_EVAL_WEIGHT: {
879*7d8d0e25Snbeams     CeedInt Q1d;
880*7d8d0e25Snbeams     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
881*7d8d0e25Snbeams     void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
882*7d8d0e25Snbeams     if (dim == 1) {
883*7d8d0e25Snbeams       const CeedInt optElems = 32/Q1d;
884*7d8d0e25Snbeams       const CeedInt elemsPerBlock = optElems>0?optElems:1;
885*7d8d0e25Snbeams       const CeedInt gridsize = nelem/elemsPerBlock + ( (
886*7d8d0e25Snbeams                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
887*7d8d0e25Snbeams       ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d,
888*7d8d0e25Snbeams                                  elemsPerBlock, 1, weightargs);
889*7d8d0e25Snbeams       CeedChk(ierr);
890*7d8d0e25Snbeams     } else if (dim == 2) {
891*7d8d0e25Snbeams       const CeedInt optElems = 32/(Q1d*Q1d);
892*7d8d0e25Snbeams       const CeedInt elemsPerBlock = optElems>0?optElems:1;
893*7d8d0e25Snbeams       const CeedInt gridsize = nelem/elemsPerBlock + ( (
894*7d8d0e25Snbeams                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
895*7d8d0e25Snbeams       ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d,
896*7d8d0e25Snbeams                                  elemsPerBlock, weightargs);
897*7d8d0e25Snbeams       CeedChk(ierr);
898*7d8d0e25Snbeams     } else if (dim == 3) {
899*7d8d0e25Snbeams       const CeedInt gridsize = nelem;
900*7d8d0e25Snbeams       ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d, Q1d,
901*7d8d0e25Snbeams                                  weightargs);
902*7d8d0e25Snbeams       CeedChk(ierr);
903*7d8d0e25Snbeams     }
904*7d8d0e25Snbeams   } break;
905*7d8d0e25Snbeams   // LCOV_EXCL_START
906*7d8d0e25Snbeams   // Evaluate the divergence to/from the quadrature points
907*7d8d0e25Snbeams   case CEED_EVAL_DIV:
908*7d8d0e25Snbeams     return CeedError(ceed, 1, "CEED_EVAL_DIV not supported");
909*7d8d0e25Snbeams   // Evaluate the curl to/from the quadrature points
910*7d8d0e25Snbeams   case CEED_EVAL_CURL:
911*7d8d0e25Snbeams     return CeedError(ceed, 1, "CEED_EVAL_CURL not supported");
912*7d8d0e25Snbeams   // Take no action, BasisApply should not have been called
913*7d8d0e25Snbeams   case CEED_EVAL_NONE:
914*7d8d0e25Snbeams     return CeedError(ceed, 1,
915*7d8d0e25Snbeams                      "CEED_EVAL_NONE does not make sense in this context");
916*7d8d0e25Snbeams     // LCOV_EXCL_STOP
917*7d8d0e25Snbeams   }
918*7d8d0e25Snbeams 
919*7d8d0e25Snbeams   // Restore vectors
920*7d8d0e25Snbeams   if (emode != CEED_EVAL_WEIGHT) {
921*7d8d0e25Snbeams     ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChk(ierr);
922*7d8d0e25Snbeams   }
923*7d8d0e25Snbeams   ierr = CeedVectorRestoreArray(v, &d_v); CeedChk(ierr);
924*7d8d0e25Snbeams   return 0;
925*7d8d0e25Snbeams }
926*7d8d0e25Snbeams 
927*7d8d0e25Snbeams //------------------------------------------------------------------------------
928*7d8d0e25Snbeams // Destroy basis
929*7d8d0e25Snbeams //------------------------------------------------------------------------------
930*7d8d0e25Snbeams static int CeedBasisDestroy_Hip_shared(CeedBasis basis) {
931*7d8d0e25Snbeams   int ierr;
932*7d8d0e25Snbeams   Ceed ceed;
933*7d8d0e25Snbeams   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
934*7d8d0e25Snbeams 
935*7d8d0e25Snbeams   CeedBasis_Hip_shared *data;
936*7d8d0e25Snbeams   ierr = CeedBasisGetData(basis, &data); CeedChk(ierr);
937*7d8d0e25Snbeams 
938*7d8d0e25Snbeams   CeedChk_Hip(ceed, hipModuleUnload(data->module));
939*7d8d0e25Snbeams 
940*7d8d0e25Snbeams   ierr = hipFree(data->d_qweight1d); CeedChk_Hip(ceed, ierr);
941*7d8d0e25Snbeams   ierr = hipFree(data->d_interp1d); CeedChk_Hip(ceed, ierr);
942*7d8d0e25Snbeams   ierr = hipFree(data->d_grad1d); CeedChk_Hip(ceed, ierr);
943*7d8d0e25Snbeams   ierr = hipFree(data->d_collograd1d); CeedChk_Hip(ceed, ierr);
944*7d8d0e25Snbeams 
945*7d8d0e25Snbeams   ierr = CeedFree(&data); CeedChk(ierr);
946*7d8d0e25Snbeams 
947*7d8d0e25Snbeams   return 0;
948*7d8d0e25Snbeams }
949*7d8d0e25Snbeams 
950*7d8d0e25Snbeams //------------------------------------------------------------------------------
951*7d8d0e25Snbeams // Create tensor basis
952*7d8d0e25Snbeams //------------------------------------------------------------------------------
953*7d8d0e25Snbeams int CeedBasisCreateTensorH1_Hip_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
954*7d8d0e25Snbeams                                        const CeedScalar *interp1d,
955*7d8d0e25Snbeams                                        const CeedScalar *grad1d,
956*7d8d0e25Snbeams                                        const CeedScalar *qref1d,
957*7d8d0e25Snbeams                                        const CeedScalar *qweight1d,
958*7d8d0e25Snbeams                                        CeedBasis basis) {
959*7d8d0e25Snbeams   int ierr;
960*7d8d0e25Snbeams   Ceed ceed;
961*7d8d0e25Snbeams   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
962*7d8d0e25Snbeams   CeedBasis_Hip_shared *data;
963*7d8d0e25Snbeams   ierr = CeedCalloc(1, &data); CeedChk(ierr);
964*7d8d0e25Snbeams 
965*7d8d0e25Snbeams   // Copy basis data to GPU
966*7d8d0e25Snbeams   const CeedInt qBytes = Q1d * sizeof(CeedScalar);
967*7d8d0e25Snbeams   ierr = hipMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Hip(ceed, ierr);
968*7d8d0e25Snbeams   ierr = hipMemcpy(data->d_qweight1d, qweight1d, qBytes,
969*7d8d0e25Snbeams                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
970*7d8d0e25Snbeams 
971*7d8d0e25Snbeams   const CeedInt iBytes = qBytes * P1d;
972*7d8d0e25Snbeams   ierr = hipMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Hip(ceed, ierr);
973*7d8d0e25Snbeams   ierr = hipMemcpy(data->d_interp1d, interp1d, iBytes,
974*7d8d0e25Snbeams                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
975*7d8d0e25Snbeams 
976*7d8d0e25Snbeams   ierr = hipMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Hip(ceed, ierr);
977*7d8d0e25Snbeams   ierr = hipMemcpy(data->d_grad1d, grad1d, iBytes,
978*7d8d0e25Snbeams                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
979*7d8d0e25Snbeams 
980*7d8d0e25Snbeams   // Compute collocated gradient and copy to GPU
981*7d8d0e25Snbeams   data->d_collograd1d = NULL;
982*7d8d0e25Snbeams   if (dim == 3 && Q1d >= P1d) {
983*7d8d0e25Snbeams     CeedScalar *collograd1d;
984*7d8d0e25Snbeams     ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChk(ierr);
985*7d8d0e25Snbeams     ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChk(ierr);
986*7d8d0e25Snbeams     ierr = hipMalloc((void **)&data->d_collograd1d, qBytes * Q1d);
987*7d8d0e25Snbeams     CeedChk_Hip(ceed, ierr);
988*7d8d0e25Snbeams     ierr = hipMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d,
989*7d8d0e25Snbeams                      hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
990*7d8d0e25Snbeams     ierr = CeedFree(&collograd1d); CeedChk(ierr);
991*7d8d0e25Snbeams   }
992*7d8d0e25Snbeams 
993*7d8d0e25Snbeams   // Compile basis kernels
994*7d8d0e25Snbeams   CeedInt ncomp;
995*7d8d0e25Snbeams   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
996*7d8d0e25Snbeams   ierr = CeedCompileHip(ceed, kernelsShared, &data->module, 8,
997*7d8d0e25Snbeams                         "Q1D", Q1d,
998*7d8d0e25Snbeams                         "P1D", P1d,
999*7d8d0e25Snbeams                         "T1D", CeedIntMax(Q1d, P1d),
1000*7d8d0e25Snbeams                         "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
1001*7d8d0e25Snbeams                             Q1d : P1d, dim),
1002*7d8d0e25Snbeams                         "BASIS_DIM", dim,
1003*7d8d0e25Snbeams                         "BASIS_NCOMP", ncomp,
1004*7d8d0e25Snbeams                         "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
1005*7d8d0e25Snbeams                         "BASIS_NQPT", CeedIntPow(Q1d, dim)
1006*7d8d0e25Snbeams                        ); CeedChk(ierr);
1007*7d8d0e25Snbeams   ierr = CeedGetKernelHip(ceed, data->module, "interp", &data->interp);
1008*7d8d0e25Snbeams   CeedChk(ierr);
1009*7d8d0e25Snbeams   ierr = CeedGetKernelHip(ceed, data->module, "grad", &data->grad);
1010*7d8d0e25Snbeams   CeedChk(ierr);
1011*7d8d0e25Snbeams   ierr = CeedGetKernelHip(ceed, data->module, "weight", &data->weight);
1012*7d8d0e25Snbeams   CeedChk(ierr);
1013*7d8d0e25Snbeams 
1014*7d8d0e25Snbeams   ierr = CeedBasisSetData(basis, data); CeedChk(ierr);
1015*7d8d0e25Snbeams 
1016*7d8d0e25Snbeams   // Register backend functions
1017*7d8d0e25Snbeams   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
1018*7d8d0e25Snbeams                                 CeedBasisApplyTensor_Hip_shared);
1019*7d8d0e25Snbeams   CeedChk(ierr);
1020*7d8d0e25Snbeams   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
1021*7d8d0e25Snbeams                                 CeedBasisDestroy_Hip_shared); CeedChk(ierr);
1022*7d8d0e25Snbeams   return 0;
1023*7d8d0e25Snbeams }
1024*7d8d0e25Snbeams //------------------------------------------------------------------------------
1025