xref: /libCEED/rust/libceed-sys/c-src/backends/hip-shared/ceed-hip-shared-basis.c (revision 9e31c45b1261f08aabe3aa77b1bb7169e37ea122)
17d8d0e25Snbeams // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
27d8d0e25Snbeams // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
37d8d0e25Snbeams // All Rights reserved. See files LICENSE and NOTICE for details.
47d8d0e25Snbeams //
57d8d0e25Snbeams // This file is part of CEED, a collection of benchmarks, miniapps, software
67d8d0e25Snbeams // libraries and APIs for efficient high-order finite element and spectral
77d8d0e25Snbeams // element discretizations for exascale applications. For more information and
87d8d0e25Snbeams // source code availability see http://github.com/ceed.
97d8d0e25Snbeams //
107d8d0e25Snbeams // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
117d8d0e25Snbeams // a collaborative effort of two U.S. Department of Energy organizations (Office
127d8d0e25Snbeams // of Science and the National Nuclear Security Administration) responsible for
137d8d0e25Snbeams // the planning and preparation of a capable exascale ecosystem, including
147d8d0e25Snbeams // software, applications, hardware, advanced system engineering and early
157d8d0e25Snbeams // testbed platforms, in support of the nation's exascale computing imperative.
167d8d0e25Snbeams 
177d8d0e25Snbeams #include "ceed-hip-shared.h"
187d8d0e25Snbeams #include "../hip/ceed-hip-compile.h"
197d8d0e25Snbeams 
207d8d0e25Snbeams //------------------------------------------------------------------------------
217d8d0e25Snbeams // Shared mem kernels
227d8d0e25Snbeams //------------------------------------------------------------------------------
237d8d0e25Snbeams // *INDENT-OFF*
247d8d0e25Snbeams static const char *kernelsShared = QUOTE(
257d8d0e25Snbeams 
267d8d0e25Snbeams //------------------------------------------------------------------------------
277d8d0e25Snbeams // Sum input into output
287d8d0e25Snbeams //------------------------------------------------------------------------------
297d8d0e25Snbeams inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
307d8d0e25Snbeams   for (int i = 0; i < P1D; i++)
317d8d0e25Snbeams     r_V[i] += r_U[i];
327d8d0e25Snbeams }
337d8d0e25Snbeams 
347d8d0e25Snbeams //------------------------------------------------------------------------------
357d8d0e25Snbeams // 1D
367d8d0e25Snbeams //------------------------------------------------------------------------------
377d8d0e25Snbeams 
387d8d0e25Snbeams //------------------------------------------------------------------------------
397d8d0e25Snbeams // Read DoFs
407d8d0e25Snbeams //------------------------------------------------------------------------------
417d8d0e25Snbeams inline __device__ void readDofs1d(const int elem, const int tidx,
427d8d0e25Snbeams                                   const int tidy, const int tidz,const int comp,
437d8d0e25Snbeams                                   const int nelem, const CeedScalar *d_U,
447d8d0e25Snbeams                                   CeedScalar *slice) {
457d8d0e25Snbeams   for (int i = 0; i < P1D; i++)
467d8d0e25Snbeams     slice[i + tidz*T1D] = d_U[i + elem*P1D + comp*P1D*nelem];
477d8d0e25Snbeams   for (int i = P1D; i < Q1D; i++)
487d8d0e25Snbeams     slice[i + tidz*T1D] = 0.0;
497d8d0e25Snbeams }
507d8d0e25Snbeams 
517d8d0e25Snbeams //------------------------------------------------------------------------------
527d8d0e25Snbeams // Write DoFs
537d8d0e25Snbeams //------------------------------------------------------------------------------
547d8d0e25Snbeams inline __device__ void writeDofs1d(const int elem, const int tidx,
557d8d0e25Snbeams                                    const int tidy, const int comp,
567d8d0e25Snbeams                                    const int nelem, const CeedScalar &r_V,
577d8d0e25Snbeams                                    CeedScalar *d_V) {
587d8d0e25Snbeams   if (tidx<P1D)
597d8d0e25Snbeams     d_V[tidx + elem*P1D + comp*P1D*nelem] = r_V;
607d8d0e25Snbeams }
617d8d0e25Snbeams 
627d8d0e25Snbeams //------------------------------------------------------------------------------
637d8d0e25Snbeams // Read quadrature point data
647d8d0e25Snbeams //------------------------------------------------------------------------------
657d8d0e25Snbeams inline __device__ void readQuads1d(const int elem, const int tidx,
667d8d0e25Snbeams                                    const int tidy, const int tidz, const int comp,
677d8d0e25Snbeams                                    const int dim, const int nelem,
687d8d0e25Snbeams                                    const CeedScalar *d_U, CeedScalar *slice) {
697d8d0e25Snbeams   for (int i = 0; i < Q1D; i++)
707d8d0e25Snbeams     slice[i + tidz*T1D] = d_U[i + elem*Q1D + comp*Q1D*nelem +
717d8d0e25Snbeams                             dim*BASIS_NCOMP*nelem*Q1D];
727d8d0e25Snbeams   for (int i = Q1D; i < P1D; i++)
737d8d0e25Snbeams     slice[i + tidz*T1D] = 0.0;
747d8d0e25Snbeams }
757d8d0e25Snbeams 
767d8d0e25Snbeams //------------------------------------------------------------------------------
777d8d0e25Snbeams // Write quadrature point data
787d8d0e25Snbeams //------------------------------------------------------------------------------
797d8d0e25Snbeams inline __device__ void writeQuads1d(const int elem, const int tidx,
807d8d0e25Snbeams                                     const int tidy, const int comp,
817d8d0e25Snbeams                                     const int dim, const int nelem,
827d8d0e25Snbeams                                     const CeedScalar &r_V, CeedScalar *d_V) {
837d8d0e25Snbeams   if (tidx<Q1D)
847d8d0e25Snbeams     d_V[tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D] = r_V;
857d8d0e25Snbeams }
867d8d0e25Snbeams 
877d8d0e25Snbeams //------------------------------------------------------------------------------
887d8d0e25Snbeams // 1D tensor contraction
897d8d0e25Snbeams //------------------------------------------------------------------------------
907d8d0e25Snbeams inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
917d8d0e25Snbeams                                    const int tidy, const int tidz,
927d8d0e25Snbeams                                    const CeedScalar &U, const CeedScalar *B,
937d8d0e25Snbeams                                    CeedScalar &V) {
947d8d0e25Snbeams   V = 0.0;
957d8d0e25Snbeams   for (int i = 0; i < P1D; ++i)
967d8d0e25Snbeams     V += B[i + tidx*P1D] * slice[i + tidz*T1D]; // Contract x direction
977d8d0e25Snbeams }
987d8d0e25Snbeams 
997d8d0e25Snbeams //------------------------------------------------------------------------------
1007d8d0e25Snbeams // 1D transpose tensor contraction
1017d8d0e25Snbeams //------------------------------------------------------------------------------
1027d8d0e25Snbeams inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
1037d8d0e25Snbeams     const int tidy, const int tidz,
1047d8d0e25Snbeams     const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
1057d8d0e25Snbeams   V = 0.0;
1067d8d0e25Snbeams   for (int i = 0; i < Q1D; ++i)
1077d8d0e25Snbeams     V += B[tidx + i*P1D] * slice[i + tidz*T1D]; // Contract x direction
1087d8d0e25Snbeams }
1097d8d0e25Snbeams 
1107d8d0e25Snbeams //------------------------------------------------------------------------------
1117d8d0e25Snbeams // 1D interpolate to quadrature points
1127d8d0e25Snbeams //------------------------------------------------------------------------------
1137d8d0e25Snbeams inline __device__ void interp1d(const CeedInt nelem, const int transpose,
1147d8d0e25Snbeams                                 const CeedScalar *c_B,
1157d8d0e25Snbeams                                 const CeedScalar *__restrict__ d_U,
1167d8d0e25Snbeams                                 CeedScalar *__restrict__ d_V,
1177d8d0e25Snbeams                                 CeedScalar *slice) {
1187d8d0e25Snbeams   CeedScalar r_V;
1197d8d0e25Snbeams   CeedScalar r_t;
1207d8d0e25Snbeams 
1217d8d0e25Snbeams   const int tidx = threadIdx.x;
1227d8d0e25Snbeams   const int tidy = threadIdx.y;
1237d8d0e25Snbeams   const int tidz = threadIdx.z;
1247d8d0e25Snbeams 
1257d8d0e25Snbeams 
1267d8d0e25Snbeams   for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
1277d8d0e25Snbeams        elem += gridDim.x*blockDim.z) {
1287d8d0e25Snbeams     for (int comp = 0; comp < BASIS_NCOMP; comp++) {
1297d8d0e25Snbeams       if (!transpose) {
1307d8d0e25Snbeams         readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
1317d8d0e25Snbeams         ContractX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
1327d8d0e25Snbeams         writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
1337d8d0e25Snbeams       } else {
1347d8d0e25Snbeams         readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
1357d8d0e25Snbeams         ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
1367d8d0e25Snbeams         writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
1377d8d0e25Snbeams       }
1387d8d0e25Snbeams     }
1397d8d0e25Snbeams   }
1407d8d0e25Snbeams }
1417d8d0e25Snbeams 
1427d8d0e25Snbeams //------------------------------------------------------------------------------
1437d8d0e25Snbeams // 1D derivatives at quadrature points
1447d8d0e25Snbeams //------------------------------------------------------------------------------
1457d8d0e25Snbeams inline __device__ void grad1d(const CeedInt nelem, const int transpose,
1467d8d0e25Snbeams                               const CeedScalar *c_B, const CeedScalar *c_G,
1477d8d0e25Snbeams                               const CeedScalar *__restrict__ d_U,
1487d8d0e25Snbeams                               CeedScalar *__restrict__ d_V,
1497d8d0e25Snbeams                               CeedScalar *slice) {
1507d8d0e25Snbeams   CeedScalar r_U;
1517d8d0e25Snbeams   CeedScalar r_V;
1527d8d0e25Snbeams 
1537d8d0e25Snbeams   const int tidx = threadIdx.x;
1547d8d0e25Snbeams   const int tidy = threadIdx.y;
1557d8d0e25Snbeams   const int tidz = threadIdx.z;
1567d8d0e25Snbeams   int dim;
1577d8d0e25Snbeams 
1587d8d0e25Snbeams   for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
1597d8d0e25Snbeams        elem += gridDim.x*blockDim.z) {
1607d8d0e25Snbeams     for(int comp = 0; comp < BASIS_NCOMP; comp++) {
1617d8d0e25Snbeams       if (!transpose) {
1627d8d0e25Snbeams         readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
1637d8d0e25Snbeams         ContractX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
1647d8d0e25Snbeams         dim = 0;
1657d8d0e25Snbeams         writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
1667d8d0e25Snbeams       } else {
1677d8d0e25Snbeams         dim = 0;
1687d8d0e25Snbeams         readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice);
1697d8d0e25Snbeams         ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
1707d8d0e25Snbeams         writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
1717d8d0e25Snbeams       }
1727d8d0e25Snbeams     }
1737d8d0e25Snbeams   }
1747d8d0e25Snbeams }
1757d8d0e25Snbeams 
1767d8d0e25Snbeams //------------------------------------------------------------------------------
1777d8d0e25Snbeams // 1D Quadrature weights
1787d8d0e25Snbeams //------------------------------------------------------------------------------
1797d8d0e25Snbeams __device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
1807d8d0e25Snbeams                          CeedScalar *w) {
1817d8d0e25Snbeams   const int tid = threadIdx.x;
1827d8d0e25Snbeams   const CeedScalar weight = qweight1d[tid];
1837d8d0e25Snbeams   for (CeedInt elem = blockIdx.x*blockDim.y + threadIdx.y; elem < nelem;
1847d8d0e25Snbeams        elem += gridDim.x*blockDim.y) {
1857d8d0e25Snbeams     const int ind = elem*Q1D + tid;
1867d8d0e25Snbeams     w[ind] = weight;
1877d8d0e25Snbeams   }
1887d8d0e25Snbeams }
1897d8d0e25Snbeams 
1907d8d0e25Snbeams //------------------------------------------------------------------------------
1917d8d0e25Snbeams // 2D
1927d8d0e25Snbeams //------------------------------------------------------------------------------
1937d8d0e25Snbeams 
1947d8d0e25Snbeams //------------------------------------------------------------------------------
1957d8d0e25Snbeams // Read DoFs
1967d8d0e25Snbeams //------------------------------------------------------------------------------
1977d8d0e25Snbeams inline __device__ void readDofs2d(const int elem, const int tidx,
1987d8d0e25Snbeams                                   const int tidy, const int comp,
1997d8d0e25Snbeams                                   const int nelem, const CeedScalar *d_U,
2007d8d0e25Snbeams                                   CeedScalar &U) {
2017d8d0e25Snbeams   U = (tidx<P1D && tidy<P1D) ?
2027d8d0e25Snbeams       d_U[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] : 0.0;
2037d8d0e25Snbeams }
2047d8d0e25Snbeams 
2057d8d0e25Snbeams //------------------------------------------------------------------------------
2067d8d0e25Snbeams // Write DoFs
2077d8d0e25Snbeams //------------------------------------------------------------------------------
2087d8d0e25Snbeams inline __device__ void writeDofs2d(const int elem, const int tidx,
2097d8d0e25Snbeams                                    const int tidy, const int comp,
2107d8d0e25Snbeams                                    const int nelem, const CeedScalar &r_V,
2117d8d0e25Snbeams                                    CeedScalar *d_V) {
2127d8d0e25Snbeams   if (tidx<P1D && tidy<P1D)
2137d8d0e25Snbeams     d_V[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] = r_V;
2147d8d0e25Snbeams }
2157d8d0e25Snbeams 
2167d8d0e25Snbeams //------------------------------------------------------------------------------
2177d8d0e25Snbeams // Read quadrature point data
2187d8d0e25Snbeams //------------------------------------------------------------------------------
2197d8d0e25Snbeams inline __device__ void readQuads2d(const int elem, const int tidx,
2207d8d0e25Snbeams                                    const int tidy, const int comp,
2217d8d0e25Snbeams                                    const int dim, const int nelem,
2227d8d0e25Snbeams                                    const CeedScalar *d_U, CeedScalar &U ) {
2237d8d0e25Snbeams   U = (tidx<Q1D && tidy<Q1D) ?
2247d8d0e25Snbeams       d_U[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
2257d8d0e25Snbeams       dim*BASIS_NCOMP*nelem*Q1D*Q1D] : 0.0;
2267d8d0e25Snbeams }
2277d8d0e25Snbeams 
2287d8d0e25Snbeams //------------------------------------------------------------------------------
2297d8d0e25Snbeams // Write quadrature point data
2307d8d0e25Snbeams //------------------------------------------------------------------------------
2317d8d0e25Snbeams inline __device__ void writeQuads2d(const int elem, const int tidx,
2327d8d0e25Snbeams                                     const int tidy, const int comp,
2337d8d0e25Snbeams                                     const int dim, const int nelem,
2347d8d0e25Snbeams                                     const CeedScalar &r_V, CeedScalar *d_V) {
2357d8d0e25Snbeams   if (tidx<Q1D && tidy<Q1D)
2367d8d0e25Snbeams     d_V[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
2377d8d0e25Snbeams     dim*BASIS_NCOMP*nelem*Q1D*Q1D] = r_V;
2387d8d0e25Snbeams }
2397d8d0e25Snbeams 
2407d8d0e25Snbeams //------------------------------------------------------------------------------
2417d8d0e25Snbeams // 2D tensor contraction x
2427d8d0e25Snbeams //------------------------------------------------------------------------------
2437d8d0e25Snbeams inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
2447d8d0e25Snbeams                                    const int tidy, const int tidz,
2457d8d0e25Snbeams                                    const CeedScalar &U, const CeedScalar *B,
2467d8d0e25Snbeams                                    CeedScalar &V) {
2477d8d0e25Snbeams   slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
2487d8d0e25Snbeams   __syncthreads();
2497d8d0e25Snbeams   V = 0.0;
2507d8d0e25Snbeams   if (tidx < Q1D)
2517d8d0e25Snbeams     for (int i = 0; i < P1D; ++i)
2527d8d0e25Snbeams       V += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
2537d8d0e25Snbeams   __syncthreads();
2547d8d0e25Snbeams }
2557d8d0e25Snbeams 
2567d8d0e25Snbeams //------------------------------------------------------------------------------
2577d8d0e25Snbeams // 2D tensor contraction y
2587d8d0e25Snbeams //------------------------------------------------------------------------------
2597d8d0e25Snbeams inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
2607d8d0e25Snbeams                                    const int tidy, const int tidz,
2617d8d0e25Snbeams                                    const CeedScalar &U, const CeedScalar *B,
2627d8d0e25Snbeams                                    CeedScalar &V) {
2637d8d0e25Snbeams   slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
2647d8d0e25Snbeams   __syncthreads();
2657d8d0e25Snbeams   V = 0.0;
2667d8d0e25Snbeams   if (tidy < Q1D)
2677d8d0e25Snbeams     for (int i = 0; i < P1D; ++i)
2687d8d0e25Snbeams       V += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
2697d8d0e25Snbeams   __syncthreads();
2707d8d0e25Snbeams }
2717d8d0e25Snbeams 
2727d8d0e25Snbeams //------------------------------------------------------------------------------
2737d8d0e25Snbeams // 2D transpose tensor contraction y
2747d8d0e25Snbeams //------------------------------------------------------------------------------
2757d8d0e25Snbeams inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
2767d8d0e25Snbeams     const int tidy, const int tidz,
2777d8d0e25Snbeams     const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
2787d8d0e25Snbeams   slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
2797d8d0e25Snbeams   __syncthreads();
2807d8d0e25Snbeams   V = 0.0;
2817d8d0e25Snbeams   if (tidy < P1D)
2827d8d0e25Snbeams     for (int i = 0; i < Q1D; ++i)
2837d8d0e25Snbeams       V += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
2847d8d0e25Snbeams   __syncthreads();
2857d8d0e25Snbeams }
2867d8d0e25Snbeams 
2877d8d0e25Snbeams //------------------------------------------------------------------------------
2887d8d0e25Snbeams // 2D transpose tensor contraction x
2897d8d0e25Snbeams //------------------------------------------------------------------------------
2907d8d0e25Snbeams inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
2917d8d0e25Snbeams     const int tidy, const int tidz,
2927d8d0e25Snbeams     const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
2937d8d0e25Snbeams   slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
2947d8d0e25Snbeams   __syncthreads();
2957d8d0e25Snbeams   V = 0.0;
2967d8d0e25Snbeams   if (tidx < P1D)
2977d8d0e25Snbeams     for (int i = 0; i < Q1D; ++i)
2987d8d0e25Snbeams       V += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
2997d8d0e25Snbeams   __syncthreads();
3007d8d0e25Snbeams }
3017d8d0e25Snbeams 
3027d8d0e25Snbeams //------------------------------------------------------------------------------
3037d8d0e25Snbeams // 2D interpolate to quadrature points
3047d8d0e25Snbeams //------------------------------------------------------------------------------
3057d8d0e25Snbeams inline __device__ void interp2d(const CeedInt nelem, const int transpose,
3067d8d0e25Snbeams                                 const CeedScalar *c_B,
3077d8d0e25Snbeams                                 const CeedScalar *__restrict__ d_U,
3087d8d0e25Snbeams                                 CeedScalar *__restrict__ d_V,
3097d8d0e25Snbeams                                 CeedScalar *slice) {
3107d8d0e25Snbeams   CeedScalar r_V;
3117d8d0e25Snbeams   CeedScalar r_t;
3127d8d0e25Snbeams 
3137d8d0e25Snbeams   const int tidx = threadIdx.x;
3147d8d0e25Snbeams   const int tidy = threadIdx.y;
3157d8d0e25Snbeams   const int tidz = threadIdx.z;
3167d8d0e25Snbeams   const int blockElem = tidz/BASIS_NCOMP;
3177d8d0e25Snbeams   const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
3187d8d0e25Snbeams   const int comp = tidz%BASIS_NCOMP;
3197d8d0e25Snbeams 
3207d8d0e25Snbeams   for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
3217d8d0e25Snbeams        elem += gridDim.x*elemsPerBlock) {
3227d8d0e25Snbeams     const int comp = tidz%BASIS_NCOMP;
3237d8d0e25Snbeams     r_V = 0.0;
3247d8d0e25Snbeams     r_t = 0.0;
3257d8d0e25Snbeams     if (!transpose) {
3267d8d0e25Snbeams       readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
3277d8d0e25Snbeams       ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
3287d8d0e25Snbeams       ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
3297d8d0e25Snbeams       writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
3307d8d0e25Snbeams     } else {
3317d8d0e25Snbeams       readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
3327d8d0e25Snbeams       ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
3337d8d0e25Snbeams       ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
3347d8d0e25Snbeams       writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
3357d8d0e25Snbeams     }
3367d8d0e25Snbeams   }
3377d8d0e25Snbeams }
3387d8d0e25Snbeams 
3397d8d0e25Snbeams //------------------------------------------------------------------------------
3407d8d0e25Snbeams // 2D derivatives at quadrature points
3417d8d0e25Snbeams //------------------------------------------------------------------------------
3427d8d0e25Snbeams inline __device__ void grad2d(const CeedInt nelem, const int transpose,
3437d8d0e25Snbeams                               const CeedScalar *c_B, const CeedScalar *c_G,
3447d8d0e25Snbeams                               const CeedScalar *__restrict__ d_U,
3457d8d0e25Snbeams                               CeedScalar *__restrict__ d_V, CeedScalar *slice) {
3467d8d0e25Snbeams   CeedScalar r_U;
3477d8d0e25Snbeams   CeedScalar r_V;
3487d8d0e25Snbeams   CeedScalar r_t;
3497d8d0e25Snbeams 
3507d8d0e25Snbeams   const int tidx = threadIdx.x;
3517d8d0e25Snbeams   const int tidy = threadIdx.y;
3527d8d0e25Snbeams   const int tidz = threadIdx.z;
3537d8d0e25Snbeams   const int blockElem = tidz/BASIS_NCOMP;
3547d8d0e25Snbeams   const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
3557d8d0e25Snbeams   const int comp = tidz%BASIS_NCOMP;
3567d8d0e25Snbeams   int dim;
3577d8d0e25Snbeams 
3587d8d0e25Snbeams   for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
3597d8d0e25Snbeams        elem += gridDim.x*elemsPerBlock) {
3607d8d0e25Snbeams     if (!transpose) {
3617d8d0e25Snbeams       readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
3627d8d0e25Snbeams       ContractX2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
3637d8d0e25Snbeams       ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
3647d8d0e25Snbeams       dim = 0;
3657d8d0e25Snbeams       writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
3667d8d0e25Snbeams       ContractX2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
3677d8d0e25Snbeams       ContractY2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
3687d8d0e25Snbeams       dim = 1;
3697d8d0e25Snbeams       writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
3707d8d0e25Snbeams     } else {
3717d8d0e25Snbeams       dim = 0;
3727d8d0e25Snbeams       readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
3737d8d0e25Snbeams       ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
3747d8d0e25Snbeams       ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
3757d8d0e25Snbeams       dim = 1;
3767d8d0e25Snbeams       readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
3777d8d0e25Snbeams       ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
3787d8d0e25Snbeams       ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
3797d8d0e25Snbeams       r_V += r_U;
3807d8d0e25Snbeams       writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
3817d8d0e25Snbeams     }
3827d8d0e25Snbeams   }
3837d8d0e25Snbeams }
3847d8d0e25Snbeams 
3857d8d0e25Snbeams //------------------------------------------------------------------------------
3867d8d0e25Snbeams // 2D quadrature weights
3877d8d0e25Snbeams //------------------------------------------------------------------------------
3887d8d0e25Snbeams __device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
3897d8d0e25Snbeams                          CeedScalar *w) {
3907d8d0e25Snbeams   const int i = threadIdx.x;
3917d8d0e25Snbeams   const int j = threadIdx.y;
3927d8d0e25Snbeams   const CeedScalar weight = qweight1d[i]*qweight1d[j];
3937d8d0e25Snbeams   for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
3947d8d0e25Snbeams        elem += gridDim.x*blockDim.z) {
3957d8d0e25Snbeams     const int ind = elem*Q1D*Q1D + i + j*Q1D;
3967d8d0e25Snbeams     w[ind] = weight;
3977d8d0e25Snbeams   }
3987d8d0e25Snbeams }
3997d8d0e25Snbeams 
4007d8d0e25Snbeams //------------------------------------------------------------------------------
4017d8d0e25Snbeams // 3D
4027d8d0e25Snbeams //------------------------------------------------------------------------------
4037d8d0e25Snbeams 
4047d8d0e25Snbeams //------------------------------------------------------------------------------
4057d8d0e25Snbeams // Read DoFs
4067d8d0e25Snbeams //------------------------------------------------------------------------------
4077d8d0e25Snbeams inline __device__ void readDofs3d(const int elem, const int tidx,
4087d8d0e25Snbeams                                   const int tidy, const int comp,
4097d8d0e25Snbeams                                   const int nelem, const CeedScalar *d_U,
4107d8d0e25Snbeams                                   CeedScalar *r_U) {
4117d8d0e25Snbeams   for (int i = 0; i < P1D; i++)
4127d8d0e25Snbeams     r_U[i] = (tidx < P1D && tidy < P1D) ?
4137d8d0e25Snbeams               d_U[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D +
4147d8d0e25Snbeams                   comp*P1D*P1D*P1D*nelem] : 0.0;
4157d8d0e25Snbeams   for (int i = P1D; i < Q1D; i++)
4167d8d0e25Snbeams     r_U[i] = 0.0;
4177d8d0e25Snbeams }
4187d8d0e25Snbeams 
4197d8d0e25Snbeams //------------------------------------------------------------------------------
4207d8d0e25Snbeams // Write DoFs
4217d8d0e25Snbeams //------------------------------------------------------------------------------
4227d8d0e25Snbeams inline __device__ void writeDofs3d(const int elem, const int tidx,
4237d8d0e25Snbeams                                    const int tidy, const int comp,
4247d8d0e25Snbeams                                    const int nelem, const CeedScalar *r_V,
4257d8d0e25Snbeams                                    CeedScalar *d_V) {
4267d8d0e25Snbeams   if (tidx < P1D && tidy < P1D) {
4277d8d0e25Snbeams     for (int i = 0; i < P1D; i++)
4287d8d0e25Snbeams       d_V[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D +
4297d8d0e25Snbeams           comp*P1D*P1D*P1D*nelem] = r_V[i];
4307d8d0e25Snbeams   }
4317d8d0e25Snbeams }
4327d8d0e25Snbeams 
4337d8d0e25Snbeams //------------------------------------------------------------------------------
4347d8d0e25Snbeams // Read quadrature point data
4357d8d0e25Snbeams //------------------------------------------------------------------------------
4367d8d0e25Snbeams inline __device__ void readQuads3d(const int elem, const int tidx,
4377d8d0e25Snbeams                                    const int tidy, const int comp,
4387d8d0e25Snbeams                                    const int dim, const int nelem,
4397d8d0e25Snbeams                                    const CeedScalar *d_U, CeedScalar *r_U) {
4407d8d0e25Snbeams   for (int i = 0; i < Q1D; i++)
4417d8d0e25Snbeams     r_U[i] = (tidx < Q1D && tidy < Q1D) ?
4427d8d0e25Snbeams               d_U[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D +
4437d8d0e25Snbeams               comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] : 0.0;
4447d8d0e25Snbeams   for (int i = Q1D; i < P1D; i++)
4457d8d0e25Snbeams     r_U[i] = 0.0;
4467d8d0e25Snbeams }
4477d8d0e25Snbeams 
4487d8d0e25Snbeams //------------------------------------------------------------------------------
4497d8d0e25Snbeams // Write quadrature point data
4507d8d0e25Snbeams //------------------------------------------------------------------------------
4517d8d0e25Snbeams inline __device__ void writeQuads3d(const int elem, const int tidx,
4527d8d0e25Snbeams                                     const int tidy, const int comp,
4537d8d0e25Snbeams                                     const int dim, const int nelem,
4547d8d0e25Snbeams                                     const CeedScalar *r_V, CeedScalar *d_V) {
4557d8d0e25Snbeams   if (tidx < Q1D && tidy < Q1D) {
4567d8d0e25Snbeams     for (int i = 0; i < Q1D; i++)
4577d8d0e25Snbeams       d_V[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
4587d8d0e25Snbeams           dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] = r_V[i];
4597d8d0e25Snbeams   }
4607d8d0e25Snbeams }
4617d8d0e25Snbeams 
4627d8d0e25Snbeams //------------------------------------------------------------------------------
4637d8d0e25Snbeams // 3D tensor contract x
4647d8d0e25Snbeams //------------------------------------------------------------------------------
4657d8d0e25Snbeams inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
4667d8d0e25Snbeams                                    const int tidy, const int tidz,
4677d8d0e25Snbeams                                    const CeedScalar *U,
4687d8d0e25Snbeams                                    const CeedScalar *B,
4697d8d0e25Snbeams                                    CeedScalar *V) {
4707d8d0e25Snbeams   for (int k = 0; k < P1D; ++k) {
4717d8d0e25Snbeams     slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
4727d8d0e25Snbeams     __syncthreads();
4737d8d0e25Snbeams     V[k] = 0.0;
4747d8d0e25Snbeams     if (tidx < Q1D && tidy < P1D)
4757d8d0e25Snbeams       for (int i = 0; i < P1D; ++i)
4767d8d0e25Snbeams         V[k] += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
4777d8d0e25Snbeams     __syncthreads();
4787d8d0e25Snbeams   }
4797d8d0e25Snbeams }
4807d8d0e25Snbeams 
4817d8d0e25Snbeams //------------------------------------------------------------------------------
4827d8d0e25Snbeams // 3D tensor contract y
4837d8d0e25Snbeams //------------------------------------------------------------------------------
4847d8d0e25Snbeams inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
4857d8d0e25Snbeams                                    const int tidy, const int tidz,
4867d8d0e25Snbeams                                    const CeedScalar *U,
4877d8d0e25Snbeams                                    const CeedScalar *B,
4887d8d0e25Snbeams                                    CeedScalar *V) {
4897d8d0e25Snbeams   for (int k = 0; k < P1D; ++k) {
4907d8d0e25Snbeams     slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
4917d8d0e25Snbeams     __syncthreads();
4927d8d0e25Snbeams     V[k] = 0.0;
4937d8d0e25Snbeams     if (tidx < Q1D && tidy < Q1D)
4947d8d0e25Snbeams       for (int i = 0; i < P1D; ++i)
4957d8d0e25Snbeams         V[k] += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
4967d8d0e25Snbeams     __syncthreads();
4977d8d0e25Snbeams   }
4987d8d0e25Snbeams }
4997d8d0e25Snbeams 
5007d8d0e25Snbeams //------------------------------------------------------------------------------
5017d8d0e25Snbeams // 3D tensor contract z
5027d8d0e25Snbeams //------------------------------------------------------------------------------
5037d8d0e25Snbeams inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
5047d8d0e25Snbeams                                    const int tidy, const int tidz,
5057d8d0e25Snbeams                                    const CeedScalar *U,
5067d8d0e25Snbeams                                    const CeedScalar *B,
5077d8d0e25Snbeams                                    CeedScalar *V) {
5087d8d0e25Snbeams   for (int k = 0; k < Q1D; ++k) {
5097d8d0e25Snbeams     V[k] = 0.0;
5107d8d0e25Snbeams     if (tidx < Q1D && tidy < Q1D)
5117d8d0e25Snbeams       for (int i = 0; i < P1D; ++i)
5127d8d0e25Snbeams         V[k] += B[i + k*P1D] * U[i]; // Contract z direction
5137d8d0e25Snbeams   }
5147d8d0e25Snbeams   for (int k = Q1D; k < P1D; ++k)
5157d8d0e25Snbeams     V[k] = 0.0;
5167d8d0e25Snbeams }
5177d8d0e25Snbeams 
5187d8d0e25Snbeams //------------------------------------------------------------------------------
5197d8d0e25Snbeams // 3D transpose tensor contract z
5207d8d0e25Snbeams //------------------------------------------------------------------------------
5217d8d0e25Snbeams inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
5227d8d0e25Snbeams                                             const int tidy, const int tidz,
5237d8d0e25Snbeams                                             const CeedScalar *U,
5247d8d0e25Snbeams                                             const CeedScalar *B,
5257d8d0e25Snbeams                                             CeedScalar *V) {
5267d8d0e25Snbeams   for (int k = 0; k < P1D; ++k) {
5277d8d0e25Snbeams     V[k] = 0.0;
5287d8d0e25Snbeams     if (tidx < Q1D && tidy < Q1D)
5297d8d0e25Snbeams       for (int i = 0; i < Q1D; ++i)
5307d8d0e25Snbeams         V[k] += B[k + i*P1D] * U[i]; // Contract z direction
5317d8d0e25Snbeams   }
5327d8d0e25Snbeams   for (int k = P1D; k < Q1D; ++k)
5337d8d0e25Snbeams     V[k] = 0.0;
5347d8d0e25Snbeams }
5357d8d0e25Snbeams 
5367d8d0e25Snbeams //------------------------------------------------------------------------------
5377d8d0e25Snbeams // 3D transpose tensor contract y
5387d8d0e25Snbeams //------------------------------------------------------------------------------
5397d8d0e25Snbeams inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
5407d8d0e25Snbeams                                             const int tidy, const int tidz,
5417d8d0e25Snbeams                                             const CeedScalar *U,
5427d8d0e25Snbeams                                             const CeedScalar *B,
5437d8d0e25Snbeams                                             CeedScalar *V) {
5447d8d0e25Snbeams   for (int k = 0; k < P1D; ++k) {
5457d8d0e25Snbeams     slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
5467d8d0e25Snbeams     __syncthreads();
5477d8d0e25Snbeams     V[k] = 0.0;
5487d8d0e25Snbeams     if (tidx < Q1D && tidy < P1D)
5497d8d0e25Snbeams       for (int i = 0; i < Q1D; ++i)
5507d8d0e25Snbeams         V[k] += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
5517d8d0e25Snbeams     __syncthreads();
5527d8d0e25Snbeams   }
5537d8d0e25Snbeams }
5547d8d0e25Snbeams 
5557d8d0e25Snbeams //------------------------------------------------------------------------------
5567d8d0e25Snbeams // 3D transpose tensor contract x
5577d8d0e25Snbeams //------------------------------------------------------------------------------
5587d8d0e25Snbeams inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
5597d8d0e25Snbeams                                             const int tidy, const int tidz,
5607d8d0e25Snbeams                                             const CeedScalar *U,
5617d8d0e25Snbeams                                             const CeedScalar *B,
5627d8d0e25Snbeams                                             CeedScalar *V) {
5637d8d0e25Snbeams   for (int k = 0; k < P1D; ++k) {
5647d8d0e25Snbeams     slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
5657d8d0e25Snbeams     __syncthreads();
5667d8d0e25Snbeams     V[k] = 0.0;
5677d8d0e25Snbeams     if (tidx < P1D && tidy < P1D)
5687d8d0e25Snbeams       for (int i = 0; i < Q1D; ++i)
5697d8d0e25Snbeams         V[k] += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
5707d8d0e25Snbeams     __syncthreads();
5717d8d0e25Snbeams   }
5727d8d0e25Snbeams }
5737d8d0e25Snbeams 
5747d8d0e25Snbeams //------------------------------------------------------------------------------
5757d8d0e25Snbeams // 3D interpolate to quadrature points
5767d8d0e25Snbeams //------------------------------------------------------------------------------
5777d8d0e25Snbeams inline __device__ void interp3d(const CeedInt nelem, const int transpose,
5787d8d0e25Snbeams                                 const CeedScalar *c_B,
5797d8d0e25Snbeams                                 const CeedScalar *__restrict__ d_U,
5807d8d0e25Snbeams                                 CeedScalar *__restrict__ d_V,
5817d8d0e25Snbeams                                 CeedScalar *slice) {
5827d8d0e25Snbeams   CeedScalar r_V[T1D];
5837d8d0e25Snbeams   CeedScalar r_t[T1D];
5847d8d0e25Snbeams 
5857d8d0e25Snbeams   const int tidx = threadIdx.x;
5867d8d0e25Snbeams   const int tidy = threadIdx.y;
5877d8d0e25Snbeams   const int tidz = threadIdx.z;
5887d8d0e25Snbeams   const int blockElem = tidz/BASIS_NCOMP;
5897d8d0e25Snbeams   const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
5907d8d0e25Snbeams   const int comp = tidz%BASIS_NCOMP;
5917d8d0e25Snbeams 
5927d8d0e25Snbeams   for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
5937d8d0e25Snbeams        elem += gridDim.x*elemsPerBlock) {
5947d8d0e25Snbeams     for (int i = 0; i < T1D; ++i) {
5957d8d0e25Snbeams       r_V[i] = 0.0;
5967d8d0e25Snbeams       r_t[i] = 0.0;
5977d8d0e25Snbeams     }
5987d8d0e25Snbeams     if (!transpose) {
5997d8d0e25Snbeams       readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
6007d8d0e25Snbeams       ContractX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
6017d8d0e25Snbeams       ContractY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
6027d8d0e25Snbeams       ContractZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
6037d8d0e25Snbeams       writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
6047d8d0e25Snbeams     } else {
6057d8d0e25Snbeams       readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
6067d8d0e25Snbeams       ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
6077d8d0e25Snbeams       ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
6087d8d0e25Snbeams       ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
6097d8d0e25Snbeams       writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
6107d8d0e25Snbeams     }
6117d8d0e25Snbeams   }
6127d8d0e25Snbeams }
6137d8d0e25Snbeams 
6147d8d0e25Snbeams //------------------------------------------------------------------------------
6157d8d0e25Snbeams // 3D derivatives at quadrature points
6167d8d0e25Snbeams //------------------------------------------------------------------------------
6177d8d0e25Snbeams inline __device__ void grad3d(const CeedInt nelem, const int transpose,
6187d8d0e25Snbeams                               const CeedScalar *c_B, const CeedScalar *c_G,
6197d8d0e25Snbeams                               const CeedScalar *__restrict__ d_U,
6207d8d0e25Snbeams                               CeedScalar *__restrict__ d_V,
6217d8d0e25Snbeams                               CeedScalar *slice) {
6227d8d0e25Snbeams   // Use P1D for one of these
6237d8d0e25Snbeams   CeedScalar r_U[T1D];
6247d8d0e25Snbeams   CeedScalar r_V[T1D];
6257d8d0e25Snbeams   CeedScalar r_t[T1D];
6267d8d0e25Snbeams 
6277d8d0e25Snbeams   const int tidx = threadIdx.x;
6287d8d0e25Snbeams   const int tidy = threadIdx.y;
6297d8d0e25Snbeams   const int tidz = threadIdx.z;
6307d8d0e25Snbeams   const int blockElem = tidz/BASIS_NCOMP;
6317d8d0e25Snbeams   const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
6327d8d0e25Snbeams   const int comp = tidz%BASIS_NCOMP;
6337d8d0e25Snbeams   int dim;
6347d8d0e25Snbeams 
6357d8d0e25Snbeams   for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
6367d8d0e25Snbeams        elem += gridDim.x*elemsPerBlock) {
6377d8d0e25Snbeams     for (int i = 0; i < T1D; ++i) {
6387d8d0e25Snbeams       r_U[i] = 0.0;
6397d8d0e25Snbeams       r_V[i] = 0.0;
6407d8d0e25Snbeams       r_t[i] = 0.0;
6417d8d0e25Snbeams     }
6427d8d0e25Snbeams     if (!transpose) {
6437d8d0e25Snbeams       readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
6447d8d0e25Snbeams       ContractX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
6457d8d0e25Snbeams       ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
6467d8d0e25Snbeams       ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
6477d8d0e25Snbeams       dim = 0;
6487d8d0e25Snbeams       writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
6497d8d0e25Snbeams       ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
6507d8d0e25Snbeams       ContractY3d(slice, tidx, tidy, tidz, r_V, c_G, r_t);
6517d8d0e25Snbeams       ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
6527d8d0e25Snbeams       dim = 1;
6537d8d0e25Snbeams       writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
6547d8d0e25Snbeams       ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
6557d8d0e25Snbeams       ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
6567d8d0e25Snbeams       ContractZ3d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
6577d8d0e25Snbeams       dim = 2;
6587d8d0e25Snbeams       writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
6597d8d0e25Snbeams     } else {
6607d8d0e25Snbeams       dim = 0;
6617d8d0e25Snbeams       readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
6627d8d0e25Snbeams       ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
6637d8d0e25Snbeams       ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
6647d8d0e25Snbeams       ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
6657d8d0e25Snbeams       dim = 1;
6667d8d0e25Snbeams       readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
6677d8d0e25Snbeams       ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
6687d8d0e25Snbeams       ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_G, r_U);
6697d8d0e25Snbeams       ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
6707d8d0e25Snbeams       add(r_V, r_t);
6717d8d0e25Snbeams       dim = 2;
6727d8d0e25Snbeams       readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
6737d8d0e25Snbeams       ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
6747d8d0e25Snbeams       ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
6757d8d0e25Snbeams       ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
6767d8d0e25Snbeams       add(r_V, r_t);
6777d8d0e25Snbeams       writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
6787d8d0e25Snbeams     }
6797d8d0e25Snbeams   }
6807d8d0e25Snbeams }
6817d8d0e25Snbeams 
6827d8d0e25Snbeams //------------------------------------------------------------------------------
6837d8d0e25Snbeams // 3D quadrature weights
6847d8d0e25Snbeams //------------------------------------------------------------------------------
6857d8d0e25Snbeams __device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
6867d8d0e25Snbeams                          CeedScalar *w) {
6877d8d0e25Snbeams   const int i = threadIdx.x;
6887d8d0e25Snbeams   const int j = threadIdx.y;
6897d8d0e25Snbeams   const int k = threadIdx.z;
6907d8d0e25Snbeams   const CeedScalar weight = qweight1d[i]*qweight1d[j]*qweight1d[k];
6917d8d0e25Snbeams   for (int e = blockIdx.x; e < nelem; e += gridDim.x) {
6927d8d0e25Snbeams     const int ind = e*Q1D*Q1D*Q1D + i + j*Q1D + k*Q1D*Q1D;
6937d8d0e25Snbeams     w[ind] = weight;
6947d8d0e25Snbeams   }
6957d8d0e25Snbeams }
6967d8d0e25Snbeams 
6977d8d0e25Snbeams 
6987d8d0e25Snbeams //------------------------------------------------------------------------------
6997d8d0e25Snbeams // Basis kernels
7007d8d0e25Snbeams //------------------------------------------------------------------------------
7017d8d0e25Snbeams 
7027d8d0e25Snbeams //------------------------------------------------------------------------------
7037d8d0e25Snbeams // Interp kernel by dim
7047d8d0e25Snbeams //------------------------------------------------------------------------------
705*9e31c45bSnbeams extern "C" __launch_bounds__(INTERP_BLKSIZE) __global__ void interp(
706*9e31c45bSnbeams                                   const CeedInt nelem, const int transpose,
7077d8d0e25Snbeams                                   const CeedScalar *c_B,
7087d8d0e25Snbeams                                   const CeedScalar *__restrict__ d_U,
7097d8d0e25Snbeams                                   CeedScalar *__restrict__ d_V) {
7107d8d0e25Snbeams   HIP_DYNAMIC_SHARED( double, slice)
7117d8d0e25Snbeams   if (BASIS_DIM == 1) {
7127d8d0e25Snbeams     interp1d(nelem, transpose, c_B, d_U, d_V, slice);
7137d8d0e25Snbeams   } else if (BASIS_DIM == 2) {
7147d8d0e25Snbeams     interp2d(nelem, transpose, c_B, d_U, d_V, slice);
7157d8d0e25Snbeams   } else if (BASIS_DIM == 3) {
7167d8d0e25Snbeams     interp3d(nelem, transpose, c_B, d_U, d_V, slice);
7177d8d0e25Snbeams   }
7187d8d0e25Snbeams }
7197d8d0e25Snbeams 
7207d8d0e25Snbeams //------------------------------------------------------------------------------
7217d8d0e25Snbeams // Grad kernel by dim
7227d8d0e25Snbeams //------------------------------------------------------------------------------
723*9e31c45bSnbeams extern "C" __launch_bounds__(GRAD_BLKSIZE) __global__ void grad(const CeedInt nelem,
724*9e31c45bSnbeams                                 const int transpose,
7257d8d0e25Snbeams                                 const CeedScalar *c_B, const CeedScalar *c_G,
7267d8d0e25Snbeams                                 const CeedScalar *__restrict__ d_U,
7277d8d0e25Snbeams                                 CeedScalar *__restrict__ d_V) {
7287d8d0e25Snbeams   HIP_DYNAMIC_SHARED( double, slice)
7297d8d0e25Snbeams   if (BASIS_DIM == 1) {
7307d8d0e25Snbeams     grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
7317d8d0e25Snbeams   } else if (BASIS_DIM == 2) {
7327d8d0e25Snbeams     grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
7337d8d0e25Snbeams   } else if (BASIS_DIM == 3) {
7347d8d0e25Snbeams     grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
7357d8d0e25Snbeams   }
7367d8d0e25Snbeams }
7377d8d0e25Snbeams 
7387d8d0e25Snbeams //------------------------------------------------------------------------------
7397d8d0e25Snbeams // Weight kernels by dim
7407d8d0e25Snbeams //------------------------------------------------------------------------------
741*9e31c45bSnbeams extern "C" __launch_bounds__(WEIGHT_BLKSIZE) __global__ void weight(const CeedInt nelem,
7427d8d0e25Snbeams                                   const CeedScalar *__restrict__ qweight1d,
7437d8d0e25Snbeams                                   CeedScalar *__restrict__ v) {
7447d8d0e25Snbeams   if (BASIS_DIM == 1) {
7457d8d0e25Snbeams     weight1d(nelem, qweight1d, v);
7467d8d0e25Snbeams   } else if (BASIS_DIM == 2) {
7477d8d0e25Snbeams     weight2d(nelem, qweight1d, v);
7487d8d0e25Snbeams   } else if (BASIS_DIM == 3) {
7497d8d0e25Snbeams     weight3d(nelem, qweight1d, v);
7507d8d0e25Snbeams   }
7517d8d0e25Snbeams }
7527d8d0e25Snbeams 
7537d8d0e25Snbeams );
7547d8d0e25Snbeams // *INDENT-ON*
7557d8d0e25Snbeams 
7567d8d0e25Snbeams //------------------------------------------------------------------------------
757*9e31c45bSnbeams // Compute a block size based on required minimum threads
758*9e31c45bSnbeams //------------------------------------------------------------------------------
759*9e31c45bSnbeams static CeedInt ComputeBlockSizeFromRequirement(const CeedInt required) {
760*9e31c45bSnbeams   CeedInt maxSize = 1024;    // Max total threads per block
761*9e31c45bSnbeams   CeedInt currentSize = 64;  // Start with one group
762*9e31c45bSnbeams 
763*9e31c45bSnbeams   while(currentSize < maxSize) {
764*9e31c45bSnbeams     if (currentSize > required)
765*9e31c45bSnbeams       break;
766*9e31c45bSnbeams     else
767*9e31c45bSnbeams       currentSize = currentSize * 2;
768*9e31c45bSnbeams   }
769*9e31c45bSnbeams   return currentSize;
770*9e31c45bSnbeams }
771*9e31c45bSnbeams 
772*9e31c45bSnbeams //------------------------------------------------------------------------------
773*9e31c45bSnbeams // Compute required thread block sizes for basis kernels given P, Q, dim, and
774*9e31c45bSnbeams // ncomp
775*9e31c45bSnbeams //------------------------------------------------------------------------------
776*9e31c45bSnbeams static int ComputeBasisThreadBlockSizes(const CeedInt dim, const CeedInt P1d,
777*9e31c45bSnbeams                                         const CeedInt Q1d,
778*9e31c45bSnbeams                                         const CeedInt ncomp, CeedInt *blksizes) {
779*9e31c45bSnbeams 
780*9e31c45bSnbeams   // Note that this will use the same block sizes for all dimensions when compiling,
781*9e31c45bSnbeams   // but as each basis object is defined for a particular dimension, we will never
782*9e31c45bSnbeams   // call any kernels except the ones for the dimension for which we have computed the
783*9e31c45bSnbeams   // block sizes.
784*9e31c45bSnbeams   const CeedInt thread1d = CeedIntMax(P1d, Q1d);
785*9e31c45bSnbeams   switch (dim) {
786*9e31c45bSnbeams   case 1: {
787*9e31c45bSnbeams     // Interp kernels:
788*9e31c45bSnbeams     blksizes[0] = 256;
789*9e31c45bSnbeams 
790*9e31c45bSnbeams     // Grad kernels:
791*9e31c45bSnbeams     blksizes[1] = 256;
792*9e31c45bSnbeams 
793*9e31c45bSnbeams     // Weight kernels:
794*9e31c45bSnbeams     blksizes[2] = 256;
795*9e31c45bSnbeams 
796*9e31c45bSnbeams   } break;
797*9e31c45bSnbeams   case 2: {
798*9e31c45bSnbeams     // Interp kernels:
799*9e31c45bSnbeams     CeedInt required = thread1d * thread1d * ncomp;
800*9e31c45bSnbeams     blksizes[0]  = ComputeBlockSizeFromRequirement(required);
801*9e31c45bSnbeams 
802*9e31c45bSnbeams     // Grad kernels: currently use same required minimum threads
803*9e31c45bSnbeams     blksizes[1]  = ComputeBlockSizeFromRequirement(required);
804*9e31c45bSnbeams 
805*9e31c45bSnbeams     // Weight kernels:
806*9e31c45bSnbeams     required = CeedIntMax(64, Q1d * Q1d);
807*9e31c45bSnbeams     blksizes[2]  = ComputeBlockSizeFromRequirement(required);
808*9e31c45bSnbeams 
809*9e31c45bSnbeams   } break;
810*9e31c45bSnbeams   case 3: {
811*9e31c45bSnbeams     // Interp kernels:
812*9e31c45bSnbeams     CeedInt required = thread1d * thread1d * ncomp;
813*9e31c45bSnbeams     blksizes[0]  = ComputeBlockSizeFromRequirement(required);
814*9e31c45bSnbeams 
815*9e31c45bSnbeams     // Grad kernels: currently use same required minimum threads
816*9e31c45bSnbeams     blksizes[1]  = ComputeBlockSizeFromRequirement(required);
817*9e31c45bSnbeams 
818*9e31c45bSnbeams     // Weight kernels:
819*9e31c45bSnbeams     required = Q1d * Q1d * Q1d;
820*9e31c45bSnbeams     blksizes[2]  = ComputeBlockSizeFromRequirement(required);
821*9e31c45bSnbeams   }
822*9e31c45bSnbeams   }
823*9e31c45bSnbeams 
824*9e31c45bSnbeams   return 0;
825*9e31c45bSnbeams }
826*9e31c45bSnbeams 
827*9e31c45bSnbeams //------------------------------------------------------------------------------
// Device initialization
8297d8d0e25Snbeams //------------------------------------------------------------------------------
// Prototypes for helpers whose definitions are not in the visible part of
// this file.
// NOTE(review): presumably implemented in the companion HIP source of this
// backend -- confirm.

// Called before the interp kernel is launched, with the device copy of the
// 1D interpolation matrix. The pointer stored through c_B is the one passed
// to the kernel as its c_B argument.
int CeedHipInitInterp(CeedScalar *d_B, CeedInt P1d, CeedInt Q1d,
                      CeedScalar **c_B);
// Called before the grad kernel is launched, with the device copies of the
// 1D interpolation and 1D gradient matrices. The pointers stored through
// c_B_ptr and c_G_ptr are the ones passed to the kernel as c_B and c_G.
int CeedHipInitInterpGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P1d,
                          CeedInt Q1d, CeedScalar **c_B_ptr,
                          CeedScalar **c_G_ptr);
8357d8d0e25Snbeams 
8367d8d0e25Snbeams //------------------------------------------------------------------------------
8377d8d0e25Snbeams // Apply basis
8387d8d0e25Snbeams //------------------------------------------------------------------------------
8397d8d0e25Snbeams int CeedBasisApplyTensor_Hip_shared(CeedBasis basis, const CeedInt nelem,
8407d8d0e25Snbeams                                     CeedTransposeMode tmode,
8417d8d0e25Snbeams                                     CeedEvalMode emode, CeedVector u,
8427d8d0e25Snbeams                                     CeedVector v) {
8437d8d0e25Snbeams   int ierr;
8447d8d0e25Snbeams   Ceed ceed;
8457d8d0e25Snbeams   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
8467d8d0e25Snbeams   Ceed_Hip_shared *ceed_Hip;
8477d8d0e25Snbeams   CeedGetData(ceed, &ceed_Hip); CeedChk(ierr);
8487d8d0e25Snbeams   CeedBasis_Hip_shared *data;
8497d8d0e25Snbeams   CeedBasisGetData(basis, &data); CeedChk(ierr);
8507d8d0e25Snbeams   const CeedInt transpose = tmode == CEED_TRANSPOSE;
8517d8d0e25Snbeams   CeedInt dim, ncomp;
8527d8d0e25Snbeams   ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
8537d8d0e25Snbeams   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
8547d8d0e25Snbeams 
8557d8d0e25Snbeams   // Read vectors
8567d8d0e25Snbeams   const CeedScalar *d_u;
8577d8d0e25Snbeams   CeedScalar *d_v;
8587d8d0e25Snbeams   if (emode != CEED_EVAL_WEIGHT) {
8597d8d0e25Snbeams     ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChk(ierr);
8607d8d0e25Snbeams   }
8617d8d0e25Snbeams   ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChk(ierr);
8627d8d0e25Snbeams 
8637d8d0e25Snbeams   // Clear v for transpose mode
8647d8d0e25Snbeams   if (tmode == CEED_TRANSPOSE) {
8657d8d0e25Snbeams     CeedInt length;
8667d8d0e25Snbeams     ierr = CeedVectorGetLength(v, &length); CeedChk(ierr);
8677d8d0e25Snbeams     ierr = hipMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChk(ierr);
8687d8d0e25Snbeams   }
8697d8d0e25Snbeams 
8707d8d0e25Snbeams   // Apply basis operation
8717d8d0e25Snbeams   switch (emode) {
8727d8d0e25Snbeams   case CEED_EVAL_INTERP: {
8737d8d0e25Snbeams     CeedInt P1d, Q1d;
874*9e31c45bSnbeams     CeedInt blksize = data->blksizes[0];
8757d8d0e25Snbeams     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
8767d8d0e25Snbeams     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
8777d8d0e25Snbeams     CeedInt thread1d = CeedIntMax(Q1d, P1d);
8787d8d0e25Snbeams     ierr = CeedHipInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B);
8797d8d0e25Snbeams     CeedChk(ierr);
8807d8d0e25Snbeams     void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
8817d8d0e25Snbeams                           &d_u, &d_v
8827d8d0e25Snbeams                          };
8837d8d0e25Snbeams     if (dim == 1) {
884e7ea6884Snbeams       CeedInt elemsPerBlock = 64*thread1d > 256? 256/thread1d : 64;
8857d8d0e25Snbeams       elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1;
8867d8d0e25Snbeams       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
8877d8d0e25Snbeams                                              ? 1 : 0 );
8887d8d0e25Snbeams       CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
8897d8d0e25Snbeams       ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, 1,
8907d8d0e25Snbeams                                        elemsPerBlock, sharedMem,
8917d8d0e25Snbeams                                        interpargs); CeedChk(ierr);
8927d8d0e25Snbeams     } else if (dim == 2) {
893*9e31c45bSnbeams       // Check if required threads is small enough to do multiple elems
894*9e31c45bSnbeams       const CeedInt elemsPerBlock = CeedIntMax(blksize/(thread1d*thread1d*ncomp), 1);
8957d8d0e25Snbeams       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
8967d8d0e25Snbeams                                              ? 1 : 0 );
8977d8d0e25Snbeams       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
8987d8d0e25Snbeams       ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d,
8997d8d0e25Snbeams                                        ncomp*elemsPerBlock, sharedMem,
9007d8d0e25Snbeams                                        interpargs); CeedChk(ierr);
9017d8d0e25Snbeams     } else if (dim == 3) {
9027d8d0e25Snbeams       CeedInt elemsPerBlock = 1;
9037d8d0e25Snbeams       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
9047d8d0e25Snbeams                                              ? 1 : 0 );
9057d8d0e25Snbeams       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
9067d8d0e25Snbeams       ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d,
9077d8d0e25Snbeams                                        ncomp*elemsPerBlock, sharedMem,
9087d8d0e25Snbeams                                        interpargs); CeedChk(ierr);
9097d8d0e25Snbeams     }
9107d8d0e25Snbeams   } break;
9117d8d0e25Snbeams   case CEED_EVAL_GRAD: {
9127d8d0e25Snbeams     CeedInt P1d, Q1d;
913*9e31c45bSnbeams     CeedInt blksize = data->blksizes[1];
9147d8d0e25Snbeams     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
9157d8d0e25Snbeams     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
9167d8d0e25Snbeams     CeedInt thread1d = CeedIntMax(Q1d, P1d);
9177d8d0e25Snbeams     ierr = CeedHipInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d,
9187d8d0e25Snbeams                                  Q1d, &data->c_B, &data->c_G);
9197d8d0e25Snbeams     CeedChk(ierr);
9207d8d0e25Snbeams     void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
9217d8d0e25Snbeams                         &data->c_G, &d_u, &d_v
9227d8d0e25Snbeams                        };
9237d8d0e25Snbeams     if (dim == 1) {
924e7ea6884Snbeams       CeedInt elemsPerBlock = 64*thread1d > 256? 256/thread1d : 64;
9257d8d0e25Snbeams       elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1;
9267d8d0e25Snbeams       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
9277d8d0e25Snbeams                                              ? 1 : 0 );
9287d8d0e25Snbeams       CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
9297d8d0e25Snbeams       ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, 1,
9307d8d0e25Snbeams                                        elemsPerBlock, sharedMem, gradargs);
9317d8d0e25Snbeams       CeedChk(ierr);
9327d8d0e25Snbeams     } else if (dim == 2) {
933*9e31c45bSnbeams       // Check if required threads is small enough to do multiple elems
934*9e31c45bSnbeams       const CeedInt elemsPerBlock = CeedIntMax(blksize/(thread1d*thread1d*ncomp), 1);
9357d8d0e25Snbeams       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
9367d8d0e25Snbeams                                              ? 1 : 0 );
9377d8d0e25Snbeams       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
9387d8d0e25Snbeams       ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d,
9397d8d0e25Snbeams                                        ncomp*elemsPerBlock, sharedMem,
9407d8d0e25Snbeams                                        gradargs); CeedChk(ierr);
9417d8d0e25Snbeams     } else if (dim == 3) {
9427d8d0e25Snbeams       CeedInt elemsPerBlock = 1;
9437d8d0e25Snbeams       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
9447d8d0e25Snbeams                                              ? 1 : 0 );
9457d8d0e25Snbeams       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
9467d8d0e25Snbeams       ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d,
9477d8d0e25Snbeams                                        ncomp*elemsPerBlock, sharedMem,
9487d8d0e25Snbeams                                        gradargs); CeedChk(ierr);
9497d8d0e25Snbeams     }
9507d8d0e25Snbeams   } break;
9517d8d0e25Snbeams   case CEED_EVAL_WEIGHT: {
9527d8d0e25Snbeams     CeedInt Q1d;
953*9e31c45bSnbeams     CeedInt blksize = data->blksizes[2];
9547d8d0e25Snbeams     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
9557d8d0e25Snbeams     void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
9567d8d0e25Snbeams     if (dim == 1) {
957*9e31c45bSnbeams       const CeedInt optElems = blksize/Q1d;
9587d8d0e25Snbeams       const CeedInt elemsPerBlock = optElems>0?optElems:1;
9597d8d0e25Snbeams       const CeedInt gridsize = nelem/elemsPerBlock + ( (
9607d8d0e25Snbeams                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
9617d8d0e25Snbeams       ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d,
9627d8d0e25Snbeams                                  elemsPerBlock, 1, weightargs);
9637d8d0e25Snbeams       CeedChk(ierr);
9647d8d0e25Snbeams     } else if (dim == 2) {
965*9e31c45bSnbeams       const CeedInt optElems = blksize/(Q1d*Q1d);
9667d8d0e25Snbeams       const CeedInt elemsPerBlock = optElems>0?optElems:1;
9677d8d0e25Snbeams       const CeedInt gridsize = nelem/elemsPerBlock + ( (
9687d8d0e25Snbeams                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
9697d8d0e25Snbeams       ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d,
9707d8d0e25Snbeams                                  elemsPerBlock, weightargs);
9717d8d0e25Snbeams       CeedChk(ierr);
9727d8d0e25Snbeams     } else if (dim == 3) {
9737d8d0e25Snbeams       const CeedInt gridsize = nelem;
9747d8d0e25Snbeams       ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d, Q1d,
9757d8d0e25Snbeams                                  weightargs);
9767d8d0e25Snbeams       CeedChk(ierr);
9777d8d0e25Snbeams     }
9787d8d0e25Snbeams   } break;
9797d8d0e25Snbeams   // LCOV_EXCL_START
9807d8d0e25Snbeams   // Evaluate the divergence to/from the quadrature points
9817d8d0e25Snbeams   case CEED_EVAL_DIV:
9827d8d0e25Snbeams     return CeedError(ceed, 1, "CEED_EVAL_DIV not supported");
9837d8d0e25Snbeams   // Evaluate the curl to/from the quadrature points
9847d8d0e25Snbeams   case CEED_EVAL_CURL:
9857d8d0e25Snbeams     return CeedError(ceed, 1, "CEED_EVAL_CURL not supported");
9867d8d0e25Snbeams   // Take no action, BasisApply should not have been called
9877d8d0e25Snbeams   case CEED_EVAL_NONE:
9887d8d0e25Snbeams     return CeedError(ceed, 1,
9897d8d0e25Snbeams                      "CEED_EVAL_NONE does not make sense in this context");
9907d8d0e25Snbeams     // LCOV_EXCL_STOP
9917d8d0e25Snbeams   }
9927d8d0e25Snbeams 
9937d8d0e25Snbeams   // Restore vectors
9947d8d0e25Snbeams   if (emode != CEED_EVAL_WEIGHT) {
9957d8d0e25Snbeams     ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChk(ierr);
9967d8d0e25Snbeams   }
9977d8d0e25Snbeams   ierr = CeedVectorRestoreArray(v, &d_v); CeedChk(ierr);
9987d8d0e25Snbeams   return 0;
9997d8d0e25Snbeams }
10007d8d0e25Snbeams 
10017d8d0e25Snbeams //------------------------------------------------------------------------------
10027d8d0e25Snbeams // Destroy basis
10037d8d0e25Snbeams //------------------------------------------------------------------------------
10047d8d0e25Snbeams static int CeedBasisDestroy_Hip_shared(CeedBasis basis) {
10057d8d0e25Snbeams   int ierr;
10067d8d0e25Snbeams   Ceed ceed;
10077d8d0e25Snbeams   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
10087d8d0e25Snbeams 
10097d8d0e25Snbeams   CeedBasis_Hip_shared *data;
10107d8d0e25Snbeams   ierr = CeedBasisGetData(basis, &data); CeedChk(ierr);
10117d8d0e25Snbeams 
10127d8d0e25Snbeams   CeedChk_Hip(ceed, hipModuleUnload(data->module));
10137d8d0e25Snbeams 
10147d8d0e25Snbeams   ierr = hipFree(data->d_qweight1d); CeedChk_Hip(ceed, ierr);
10157d8d0e25Snbeams   ierr = hipFree(data->d_interp1d); CeedChk_Hip(ceed, ierr);
10167d8d0e25Snbeams   ierr = hipFree(data->d_grad1d); CeedChk_Hip(ceed, ierr);
10177d8d0e25Snbeams   ierr = hipFree(data->d_collograd1d); CeedChk_Hip(ceed, ierr);
10187d8d0e25Snbeams 
10197d8d0e25Snbeams   ierr = CeedFree(&data); CeedChk(ierr);
10207d8d0e25Snbeams 
10217d8d0e25Snbeams   return 0;
10227d8d0e25Snbeams }
10237d8d0e25Snbeams 
10247d8d0e25Snbeams //------------------------------------------------------------------------------
10257d8d0e25Snbeams // Create tensor basis
10267d8d0e25Snbeams //------------------------------------------------------------------------------
10277d8d0e25Snbeams int CeedBasisCreateTensorH1_Hip_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
10287d8d0e25Snbeams                                        const CeedScalar *interp1d,
10297d8d0e25Snbeams                                        const CeedScalar *grad1d,
10307d8d0e25Snbeams                                        const CeedScalar *qref1d,
10317d8d0e25Snbeams                                        const CeedScalar *qweight1d,
10327d8d0e25Snbeams                                        CeedBasis basis) {
10337d8d0e25Snbeams   int ierr;
10347d8d0e25Snbeams   Ceed ceed;
10357d8d0e25Snbeams   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
10367d8d0e25Snbeams   CeedBasis_Hip_shared *data;
10377d8d0e25Snbeams   ierr = CeedCalloc(1, &data); CeedChk(ierr);
10387d8d0e25Snbeams 
10397d8d0e25Snbeams   // Copy basis data to GPU
10407d8d0e25Snbeams   const CeedInt qBytes = Q1d * sizeof(CeedScalar);
10417d8d0e25Snbeams   ierr = hipMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Hip(ceed, ierr);
10427d8d0e25Snbeams   ierr = hipMemcpy(data->d_qweight1d, qweight1d, qBytes,
10437d8d0e25Snbeams                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
10447d8d0e25Snbeams 
10457d8d0e25Snbeams   const CeedInt iBytes = qBytes * P1d;
10467d8d0e25Snbeams   ierr = hipMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Hip(ceed, ierr);
10477d8d0e25Snbeams   ierr = hipMemcpy(data->d_interp1d, interp1d, iBytes,
10487d8d0e25Snbeams                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
10497d8d0e25Snbeams 
10507d8d0e25Snbeams   ierr = hipMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Hip(ceed, ierr);
10517d8d0e25Snbeams   ierr = hipMemcpy(data->d_grad1d, grad1d, iBytes,
10527d8d0e25Snbeams                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
10537d8d0e25Snbeams 
10547d8d0e25Snbeams   // Compute collocated gradient and copy to GPU
10557d8d0e25Snbeams   data->d_collograd1d = NULL;
10567d8d0e25Snbeams   if (dim == 3 && Q1d >= P1d) {
10577d8d0e25Snbeams     CeedScalar *collograd1d;
10587d8d0e25Snbeams     ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChk(ierr);
10597d8d0e25Snbeams     ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChk(ierr);
10607d8d0e25Snbeams     ierr = hipMalloc((void **)&data->d_collograd1d, qBytes * Q1d);
10617d8d0e25Snbeams     CeedChk_Hip(ceed, ierr);
10627d8d0e25Snbeams     ierr = hipMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d,
10637d8d0e25Snbeams                      hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
10647d8d0e25Snbeams     ierr = CeedFree(&collograd1d); CeedChk(ierr);
10657d8d0e25Snbeams   }
10667d8d0e25Snbeams 
1067*9e31c45bSnbeams   // Set number of threads per block for basis kernels
10687d8d0e25Snbeams   CeedInt ncomp;
10697d8d0e25Snbeams   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
1070*9e31c45bSnbeams   ierr = ComputeBasisThreadBlockSizes(dim, P1d, Q1d, ncomp, data->blksizes);
1071*9e31c45bSnbeams   CeedChk(ierr);
1072*9e31c45bSnbeams 
1073*9e31c45bSnbeams   // Compile basis kernels
1074*9e31c45bSnbeams   ierr = CeedCompileHip(ceed, kernelsShared, &data->module, 11,
10757d8d0e25Snbeams                         "Q1D", Q1d,
10767d8d0e25Snbeams                         "P1D", P1d,
10777d8d0e25Snbeams                         "T1D", CeedIntMax(Q1d, P1d),
10787d8d0e25Snbeams                         "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
10797d8d0e25Snbeams                             Q1d : P1d, dim),
10807d8d0e25Snbeams                         "BASIS_DIM", dim,
10817d8d0e25Snbeams                         "BASIS_NCOMP", ncomp,
10827d8d0e25Snbeams                         "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
1083*9e31c45bSnbeams                         "BASIS_NQPT", CeedIntPow(Q1d, dim),
1084*9e31c45bSnbeams                         "INTERP_BLKSIZE", data->blksizes[0],
1085*9e31c45bSnbeams                         "GRAD_BLKSIZE", data->blksizes[1],
1086*9e31c45bSnbeams                         "WEIGHT_BLKSIZE", data->blksizes[2]
10877d8d0e25Snbeams                        ); CeedChk(ierr);
10887d8d0e25Snbeams   ierr = CeedGetKernelHip(ceed, data->module, "interp", &data->interp);
10897d8d0e25Snbeams   CeedChk(ierr);
10907d8d0e25Snbeams   ierr = CeedGetKernelHip(ceed, data->module, "grad", &data->grad);
10917d8d0e25Snbeams   CeedChk(ierr);
10927d8d0e25Snbeams   ierr = CeedGetKernelHip(ceed, data->module, "weight", &data->weight);
10937d8d0e25Snbeams   CeedChk(ierr);
10947d8d0e25Snbeams 
10957d8d0e25Snbeams   ierr = CeedBasisSetData(basis, data); CeedChk(ierr);
10967d8d0e25Snbeams 
10977d8d0e25Snbeams   // Register backend functions
10987d8d0e25Snbeams   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
10997d8d0e25Snbeams                                 CeedBasisApplyTensor_Hip_shared);
11007d8d0e25Snbeams   CeedChk(ierr);
11017d8d0e25Snbeams   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
11027d8d0e25Snbeams                                 CeedBasisDestroy_Hip_shared); CeedChk(ierr);
11037d8d0e25Snbeams   return 0;
11047d8d0e25Snbeams }
11057d8d0e25Snbeams //------------------------------------------------------------------------------
1106