xref: /libCEED/rust/libceed-sys/c-src/backends/hip-shared/ceed-hip-shared-basis.c (revision 6dbfb411544a7a6bdd33f391c97c69cd9e1f444a)
17d8d0e25Snbeams // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
27d8d0e25Snbeams // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
37d8d0e25Snbeams // All Rights reserved. See files LICENSE and NOTICE for details.
47d8d0e25Snbeams //
57d8d0e25Snbeams // This file is part of CEED, a collection of benchmarks, miniapps, software
67d8d0e25Snbeams // libraries and APIs for efficient high-order finite element and spectral
77d8d0e25Snbeams // element discretizations for exascale applications. For more information and
87d8d0e25Snbeams // source code availability see http://github.com/ceed.
97d8d0e25Snbeams //
107d8d0e25Snbeams // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
117d8d0e25Snbeams // a collaborative effort of two U.S. Department of Energy organizations (Office
127d8d0e25Snbeams // of Science and the National Nuclear Security Administration) responsible for
137d8d0e25Snbeams // the planning and preparation of a capable exascale ecosystem, including
147d8d0e25Snbeams // software, applications, hardware, advanced system engineering and early
157d8d0e25Snbeams // testbed platforms, in support of the nation's exascale computing imperative.
167d8d0e25Snbeams 
17ec3da8bcSJed Brown #include <ceed/ceed.h>
18ec3da8bcSJed Brown #include <ceed/backend.h>
193d576824SJeremy L Thompson #include <hip/hip_runtime.h>
203d576824SJeremy L Thompson #include <stddef.h>
217d8d0e25Snbeams #include "ceed-hip-shared.h"
223d576824SJeremy L Thompson #include "../hip/ceed-hip.h"
237d8d0e25Snbeams #include "../hip/ceed-hip-compile.h"
247d8d0e25Snbeams 
257d8d0e25Snbeams //------------------------------------------------------------------------------
267d8d0e25Snbeams // Shared mem kernels
277d8d0e25Snbeams //------------------------------------------------------------------------------
287d8d0e25Snbeams // *INDENT-OFF*
297d8d0e25Snbeams static const char *kernelsShared = QUOTE(
307d8d0e25Snbeams 
317d8d0e25Snbeams //------------------------------------------------------------------------------
327d8d0e25Snbeams // Sum input into output
337d8d0e25Snbeams //------------------------------------------------------------------------------
347d8d0e25Snbeams inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
357d8d0e25Snbeams   for (int i = 0; i < P1D; i++)
367d8d0e25Snbeams     r_V[i] += r_U[i];
377d8d0e25Snbeams }
387d8d0e25Snbeams 
397d8d0e25Snbeams //------------------------------------------------------------------------------
409dd88646Snbeams // Load matrices for basis actions
419dd88646Snbeams //------------------------------------------------------------------------------
429dd88646Snbeams inline __device__ void loadMatrix(const CeedScalar* d_B, CeedScalar* B) {
439dd88646Snbeams   CeedInt tid = threadIdx.x + threadIdx.y*blockDim.x + threadIdx.z*blockDim.y*blockDim.x;
449dd88646Snbeams   for (CeedInt i = tid; i < P1D*Q1D; i += blockDim.x*blockDim.y*blockDim.z)
459dd88646Snbeams     B[i] = d_B[i];
469dd88646Snbeams }
479dd88646Snbeams 
489dd88646Snbeams //------------------------------------------------------------------------------
497d8d0e25Snbeams // 1D
507d8d0e25Snbeams //------------------------------------------------------------------------------
517d8d0e25Snbeams 
527d8d0e25Snbeams //------------------------------------------------------------------------------
537d8d0e25Snbeams // Read DoFs
547d8d0e25Snbeams //------------------------------------------------------------------------------
557d8d0e25Snbeams inline __device__ void readDofs1d(const int elem, const int tidx,
567d8d0e25Snbeams                                   const int tidy, const int tidz,const int comp,
577d8d0e25Snbeams                                   const int nelem, const CeedScalar *d_U,
587d8d0e25Snbeams                                   CeedScalar *slice) {
597d8d0e25Snbeams   for (int i = 0; i < P1D; i++)
607d8d0e25Snbeams     slice[i + tidz*T1D] = d_U[i + elem*P1D + comp*P1D*nelem];
617d8d0e25Snbeams   for (int i = P1D; i < Q1D; i++)
627d8d0e25Snbeams     slice[i + tidz*T1D] = 0.0;
637d8d0e25Snbeams }
647d8d0e25Snbeams 
657d8d0e25Snbeams //------------------------------------------------------------------------------
667d8d0e25Snbeams // Write DoFs
677d8d0e25Snbeams //------------------------------------------------------------------------------
687d8d0e25Snbeams inline __device__ void writeDofs1d(const int elem, const int tidx,
697d8d0e25Snbeams                                    const int tidy, const int comp,
707d8d0e25Snbeams                                    const int nelem, const CeedScalar &r_V,
717d8d0e25Snbeams                                    CeedScalar *d_V) {
727d8d0e25Snbeams   if (tidx<P1D)
737d8d0e25Snbeams     d_V[tidx + elem*P1D + comp*P1D*nelem] = r_V;
747d8d0e25Snbeams }
757d8d0e25Snbeams 
767d8d0e25Snbeams //------------------------------------------------------------------------------
777d8d0e25Snbeams // Read quadrature point data
787d8d0e25Snbeams //------------------------------------------------------------------------------
797d8d0e25Snbeams inline __device__ void readQuads1d(const int elem, const int tidx,
807d8d0e25Snbeams                                    const int tidy, const int tidz, const int comp,
817d8d0e25Snbeams                                    const int dim, const int nelem,
827d8d0e25Snbeams                                    const CeedScalar *d_U, CeedScalar *slice) {
837d8d0e25Snbeams   for (int i = 0; i < Q1D; i++)
847d8d0e25Snbeams     slice[i + tidz*T1D] = d_U[i + elem*Q1D + comp*Q1D*nelem +
857d8d0e25Snbeams                             dim*BASIS_NCOMP*nelem*Q1D];
867d8d0e25Snbeams   for (int i = Q1D; i < P1D; i++)
877d8d0e25Snbeams     slice[i + tidz*T1D] = 0.0;
887d8d0e25Snbeams }
897d8d0e25Snbeams 
907d8d0e25Snbeams //------------------------------------------------------------------------------
917d8d0e25Snbeams // Write quadrature point data
927d8d0e25Snbeams //------------------------------------------------------------------------------
937d8d0e25Snbeams inline __device__ void writeQuads1d(const int elem, const int tidx,
947d8d0e25Snbeams                                     const int tidy, const int comp,
957d8d0e25Snbeams                                     const int dim, const int nelem,
967d8d0e25Snbeams                                     const CeedScalar &r_V, CeedScalar *d_V) {
977d8d0e25Snbeams   if (tidx<Q1D)
987d8d0e25Snbeams     d_V[tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D] = r_V;
997d8d0e25Snbeams }
1007d8d0e25Snbeams 
1017d8d0e25Snbeams //------------------------------------------------------------------------------
1027d8d0e25Snbeams // 1D tensor contraction
1037d8d0e25Snbeams //------------------------------------------------------------------------------
1047d8d0e25Snbeams inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
1057d8d0e25Snbeams                                    const int tidy, const int tidz,
1067d8d0e25Snbeams                                    const CeedScalar &U, const CeedScalar *B,
1077d8d0e25Snbeams                                    CeedScalar &V) {
1087d8d0e25Snbeams   V = 0.0;
1097d8d0e25Snbeams   for (int i = 0; i < P1D; ++i)
1107d8d0e25Snbeams     V += B[i + tidx*P1D] * slice[i + tidz*T1D]; // Contract x direction
1117d8d0e25Snbeams }
1127d8d0e25Snbeams 
1137d8d0e25Snbeams //------------------------------------------------------------------------------
1147d8d0e25Snbeams // 1D transpose tensor contraction
1157d8d0e25Snbeams //------------------------------------------------------------------------------
1167d8d0e25Snbeams inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
1177d8d0e25Snbeams     const int tidy, const int tidz,
1187d8d0e25Snbeams     const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
1197d8d0e25Snbeams   V = 0.0;
1207d8d0e25Snbeams   for (int i = 0; i < Q1D; ++i)
1217d8d0e25Snbeams     V += B[tidx + i*P1D] * slice[i + tidz*T1D]; // Contract x direction
1227d8d0e25Snbeams }
1237d8d0e25Snbeams 
1247d8d0e25Snbeams //------------------------------------------------------------------------------
1257d8d0e25Snbeams // 1D interpolate to quadrature points
1267d8d0e25Snbeams //------------------------------------------------------------------------------
1277d8d0e25Snbeams inline __device__ void interp1d(const CeedInt nelem, const int transpose,
128e9132427Snbeams                                 const CeedScalar *s_B,
1297d8d0e25Snbeams                                 const CeedScalar *__restrict__ d_U,
1307d8d0e25Snbeams                                 CeedScalar *__restrict__ d_V,
1317d8d0e25Snbeams                                 CeedScalar *slice) {
1327d8d0e25Snbeams   CeedScalar r_V;
1337d8d0e25Snbeams   CeedScalar r_t;
1347d8d0e25Snbeams 
1357d8d0e25Snbeams   const int tidx = threadIdx.x;
1367d8d0e25Snbeams   const int tidy = threadIdx.y;
1377d8d0e25Snbeams   const int tidz = threadIdx.z;
1387d8d0e25Snbeams 
1397d8d0e25Snbeams   for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
1407d8d0e25Snbeams        elem += gridDim.x*blockDim.z) {
1417d8d0e25Snbeams     for (int comp = 0; comp < BASIS_NCOMP; comp++) {
1427d8d0e25Snbeams       if (!transpose) {
1437d8d0e25Snbeams         readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
144e9132427Snbeams         ContractX1d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
1457d8d0e25Snbeams         writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
1467d8d0e25Snbeams       } else {
1477d8d0e25Snbeams         readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
148e9132427Snbeams         ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
1497d8d0e25Snbeams         writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
1507d8d0e25Snbeams       }
1517d8d0e25Snbeams     }
1527d8d0e25Snbeams   }
1537d8d0e25Snbeams }
1547d8d0e25Snbeams 
1557d8d0e25Snbeams //------------------------------------------------------------------------------
1567d8d0e25Snbeams // 1D derivatives at quadrature points
1577d8d0e25Snbeams //------------------------------------------------------------------------------
1587d8d0e25Snbeams inline __device__ void grad1d(const CeedInt nelem, const int transpose,
159e9132427Snbeams                               const CeedScalar *s_B, const CeedScalar *s_G,
1607d8d0e25Snbeams                               const CeedScalar *__restrict__ d_U,
1617d8d0e25Snbeams                               CeedScalar *__restrict__ d_V,
1627d8d0e25Snbeams                               CeedScalar *slice) {
1637d8d0e25Snbeams   CeedScalar r_U;
1647d8d0e25Snbeams   CeedScalar r_V;
1657d8d0e25Snbeams 
1667d8d0e25Snbeams   const int tidx = threadIdx.x;
1677d8d0e25Snbeams   const int tidy = threadIdx.y;
1687d8d0e25Snbeams   const int tidz = threadIdx.z;
1697d8d0e25Snbeams   int dim;
1707d8d0e25Snbeams 
1717d8d0e25Snbeams   for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
1727d8d0e25Snbeams        elem += gridDim.x*blockDim.z) {
1737d8d0e25Snbeams     for(int comp = 0; comp < BASIS_NCOMP; comp++) {
1747d8d0e25Snbeams       if (!transpose) {
1757d8d0e25Snbeams         readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
176e9132427Snbeams         ContractX1d(slice, tidx, tidy, tidz, r_U, s_G, r_V);
1777d8d0e25Snbeams         dim = 0;
1787d8d0e25Snbeams         writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
1797d8d0e25Snbeams       } else {
1807d8d0e25Snbeams         dim = 0;
1817d8d0e25Snbeams         readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice);
182e9132427Snbeams         ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, s_G, r_V);
1837d8d0e25Snbeams         writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
1847d8d0e25Snbeams       }
1857d8d0e25Snbeams     }
1867d8d0e25Snbeams   }
1877d8d0e25Snbeams }
1887d8d0e25Snbeams 
1897d8d0e25Snbeams //------------------------------------------------------------------------------
1907d8d0e25Snbeams // 1D Quadrature weights
1917d8d0e25Snbeams //------------------------------------------------------------------------------
1927d8d0e25Snbeams __device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
1937d8d0e25Snbeams                          CeedScalar *w) {
1947d8d0e25Snbeams   const int tid = threadIdx.x;
1957d8d0e25Snbeams   const CeedScalar weight = qweight1d[tid];
1967d8d0e25Snbeams   for (CeedInt elem = blockIdx.x*blockDim.y + threadIdx.y; elem < nelem;
1977d8d0e25Snbeams        elem += gridDim.x*blockDim.y) {
1987d8d0e25Snbeams     const int ind = elem*Q1D + tid;
1997d8d0e25Snbeams     w[ind] = weight;
2007d8d0e25Snbeams   }
2017d8d0e25Snbeams }
2027d8d0e25Snbeams 
2037d8d0e25Snbeams //------------------------------------------------------------------------------
2047d8d0e25Snbeams // 2D
2057d8d0e25Snbeams //------------------------------------------------------------------------------
2067d8d0e25Snbeams 
2077d8d0e25Snbeams //------------------------------------------------------------------------------
2087d8d0e25Snbeams // Read DoFs
2097d8d0e25Snbeams //------------------------------------------------------------------------------
2107d8d0e25Snbeams inline __device__ void readDofs2d(const int elem, const int tidx,
2117d8d0e25Snbeams                                   const int tidy, const int comp,
2127d8d0e25Snbeams                                   const int nelem, const CeedScalar *d_U,
2137d8d0e25Snbeams                                   CeedScalar &U) {
2147d8d0e25Snbeams   U = (tidx<P1D && tidy<P1D) ?
2157d8d0e25Snbeams       d_U[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] : 0.0;
2167d8d0e25Snbeams }
2177d8d0e25Snbeams 
2187d8d0e25Snbeams //------------------------------------------------------------------------------
2197d8d0e25Snbeams // Write DoFs
2207d8d0e25Snbeams //------------------------------------------------------------------------------
2217d8d0e25Snbeams inline __device__ void writeDofs2d(const int elem, const int tidx,
2227d8d0e25Snbeams                                    const int tidy, const int comp,
2237d8d0e25Snbeams                                    const int nelem, const CeedScalar &r_V,
2247d8d0e25Snbeams                                    CeedScalar *d_V) {
2257d8d0e25Snbeams   if (tidx<P1D && tidy<P1D)
2267d8d0e25Snbeams     d_V[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] = r_V;
2277d8d0e25Snbeams }
2287d8d0e25Snbeams 
2297d8d0e25Snbeams //------------------------------------------------------------------------------
2307d8d0e25Snbeams // Read quadrature point data
2317d8d0e25Snbeams //------------------------------------------------------------------------------
2327d8d0e25Snbeams inline __device__ void readQuads2d(const int elem, const int tidx,
2337d8d0e25Snbeams                                    const int tidy, const int comp,
2347d8d0e25Snbeams                                    const int dim, const int nelem,
2357d8d0e25Snbeams                                    const CeedScalar *d_U, CeedScalar &U ) {
2367d8d0e25Snbeams   U = (tidx<Q1D && tidy<Q1D) ?
2377d8d0e25Snbeams       d_U[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
2387d8d0e25Snbeams       dim*BASIS_NCOMP*nelem*Q1D*Q1D] : 0.0;
2397d8d0e25Snbeams }
2407d8d0e25Snbeams 
2417d8d0e25Snbeams //------------------------------------------------------------------------------
2427d8d0e25Snbeams // Write quadrature point data
2437d8d0e25Snbeams //------------------------------------------------------------------------------
2447d8d0e25Snbeams inline __device__ void writeQuads2d(const int elem, const int tidx,
2457d8d0e25Snbeams                                     const int tidy, const int comp,
2467d8d0e25Snbeams                                     const int dim, const int nelem,
2477d8d0e25Snbeams                                     const CeedScalar &r_V, CeedScalar *d_V) {
2487d8d0e25Snbeams   if (tidx<Q1D && tidy<Q1D)
2497d8d0e25Snbeams     d_V[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
2507d8d0e25Snbeams     dim*BASIS_NCOMP*nelem*Q1D*Q1D] = r_V;
2517d8d0e25Snbeams }
2527d8d0e25Snbeams 
2537d8d0e25Snbeams //------------------------------------------------------------------------------
2547d8d0e25Snbeams // 2D tensor contraction x
2557d8d0e25Snbeams //------------------------------------------------------------------------------
2567d8d0e25Snbeams inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
2577d8d0e25Snbeams                                    const int tidy, const int tidz,
2587d8d0e25Snbeams                                    const CeedScalar &U, const CeedScalar *B,
2597d8d0e25Snbeams                                    CeedScalar &V) {
2607d8d0e25Snbeams   slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
2617d8d0e25Snbeams   __syncthreads();
2627d8d0e25Snbeams   V = 0.0;
2637d8d0e25Snbeams   if (tidx < Q1D)
2647d8d0e25Snbeams     for (int i = 0; i < P1D; ++i)
2657d8d0e25Snbeams       V += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
2667d8d0e25Snbeams   __syncthreads();
2677d8d0e25Snbeams }
2687d8d0e25Snbeams 
2697d8d0e25Snbeams //------------------------------------------------------------------------------
2707d8d0e25Snbeams // 2D tensor contraction y
2717d8d0e25Snbeams //------------------------------------------------------------------------------
2727d8d0e25Snbeams inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
2737d8d0e25Snbeams                                    const int tidy, const int tidz,
2747d8d0e25Snbeams                                    const CeedScalar &U, const CeedScalar *B,
2757d8d0e25Snbeams                                    CeedScalar &V) {
2767d8d0e25Snbeams   slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
2777d8d0e25Snbeams   __syncthreads();
2787d8d0e25Snbeams   V = 0.0;
2797d8d0e25Snbeams   if (tidy < Q1D)
2807d8d0e25Snbeams     for (int i = 0; i < P1D; ++i)
2817d8d0e25Snbeams       V += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
2827d8d0e25Snbeams   __syncthreads();
2837d8d0e25Snbeams }
2847d8d0e25Snbeams 
2857d8d0e25Snbeams //------------------------------------------------------------------------------
2867d8d0e25Snbeams // 2D transpose tensor contraction y
2877d8d0e25Snbeams //------------------------------------------------------------------------------
2887d8d0e25Snbeams inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
2897d8d0e25Snbeams     const int tidy, const int tidz,
2907d8d0e25Snbeams     const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
2917d8d0e25Snbeams   slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
2927d8d0e25Snbeams   __syncthreads();
2937d8d0e25Snbeams   V = 0.0;
2947d8d0e25Snbeams   if (tidy < P1D)
2957d8d0e25Snbeams     for (int i = 0; i < Q1D; ++i)
2967d8d0e25Snbeams       V += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
2977d8d0e25Snbeams   __syncthreads();
2987d8d0e25Snbeams }
2997d8d0e25Snbeams 
3007d8d0e25Snbeams //------------------------------------------------------------------------------
3017d8d0e25Snbeams // 2D transpose tensor contraction x
3027d8d0e25Snbeams //------------------------------------------------------------------------------
3037d8d0e25Snbeams inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
3047d8d0e25Snbeams     const int tidy, const int tidz,
3057d8d0e25Snbeams     const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
3067d8d0e25Snbeams   slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
3077d8d0e25Snbeams   __syncthreads();
3087d8d0e25Snbeams   V = 0.0;
3097d8d0e25Snbeams   if (tidx < P1D)
3107d8d0e25Snbeams     for (int i = 0; i < Q1D; ++i)
3117d8d0e25Snbeams       V += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
3127d8d0e25Snbeams   __syncthreads();
3137d8d0e25Snbeams }
3147d8d0e25Snbeams 
3157d8d0e25Snbeams //------------------------------------------------------------------------------
3167d8d0e25Snbeams // 2D interpolate to quadrature points
3177d8d0e25Snbeams //------------------------------------------------------------------------------
3187d8d0e25Snbeams inline __device__ void interp2d(const CeedInt nelem, const int transpose,
319e9132427Snbeams                                 const CeedScalar *s_B,
3207d8d0e25Snbeams                                 const CeedScalar *__restrict__ d_U,
3217d8d0e25Snbeams                                 CeedScalar *__restrict__ d_V,
3227d8d0e25Snbeams                                 CeedScalar *slice) {
3237d8d0e25Snbeams   CeedScalar r_V;
3247d8d0e25Snbeams   CeedScalar r_t;
3257d8d0e25Snbeams 
3267d8d0e25Snbeams   const int tidx = threadIdx.x;
3277d8d0e25Snbeams   const int tidy = threadIdx.y;
3287d8d0e25Snbeams   const int tidz = threadIdx.z;
3297d8d0e25Snbeams   const int blockElem = tidz/BASIS_NCOMP;
3307d8d0e25Snbeams   const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
3317d8d0e25Snbeams   const int comp = tidz%BASIS_NCOMP;
3327d8d0e25Snbeams 
3337d8d0e25Snbeams   for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
3347d8d0e25Snbeams        elem += gridDim.x*elemsPerBlock) {
3357d8d0e25Snbeams     const int comp = tidz%BASIS_NCOMP;
3367d8d0e25Snbeams     r_V = 0.0;
3377d8d0e25Snbeams     r_t = 0.0;
3387d8d0e25Snbeams     if (!transpose) {
3397d8d0e25Snbeams       readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
340e9132427Snbeams       ContractX2d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
341e9132427Snbeams       ContractY2d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
3427d8d0e25Snbeams       writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
3437d8d0e25Snbeams     } else {
3447d8d0e25Snbeams       readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
345e9132427Snbeams       ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
346e9132427Snbeams       ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
3477d8d0e25Snbeams       writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
3487d8d0e25Snbeams     }
3497d8d0e25Snbeams   }
3507d8d0e25Snbeams }
3517d8d0e25Snbeams 
3527d8d0e25Snbeams //------------------------------------------------------------------------------
3537d8d0e25Snbeams // 2D derivatives at quadrature points
3547d8d0e25Snbeams //------------------------------------------------------------------------------
3557d8d0e25Snbeams inline __device__ void grad2d(const CeedInt nelem, const int transpose,
356e9132427Snbeams                               const CeedScalar *s_B, const CeedScalar *s_G,
3577d8d0e25Snbeams                               const CeedScalar *__restrict__ d_U,
3587d8d0e25Snbeams                               CeedScalar *__restrict__ d_V, CeedScalar *slice) {
3597d8d0e25Snbeams   CeedScalar r_U;
3607d8d0e25Snbeams   CeedScalar r_V;
3617d8d0e25Snbeams   CeedScalar r_t;
3627d8d0e25Snbeams 
3637d8d0e25Snbeams   const int tidx = threadIdx.x;
3647d8d0e25Snbeams   const int tidy = threadIdx.y;
3657d8d0e25Snbeams   const int tidz = threadIdx.z;
3667d8d0e25Snbeams   const int blockElem = tidz/BASIS_NCOMP;
3677d8d0e25Snbeams   const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
3687d8d0e25Snbeams   const int comp = tidz%BASIS_NCOMP;
3697d8d0e25Snbeams   int dim;
3707d8d0e25Snbeams 
3717d8d0e25Snbeams   for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
3727d8d0e25Snbeams        elem += gridDim.x*elemsPerBlock) {
3737d8d0e25Snbeams     if (!transpose) {
3747d8d0e25Snbeams       readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
375e9132427Snbeams       ContractX2d(slice, tidx, tidy, tidz, r_U, s_G, r_t);
376e9132427Snbeams       ContractY2d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
3777d8d0e25Snbeams       dim = 0;
3787d8d0e25Snbeams       writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
379e9132427Snbeams       ContractX2d(slice, tidx, tidy, tidz, r_U, s_B, r_t);
380e9132427Snbeams       ContractY2d(slice, tidx, tidy, tidz, r_t, s_G, r_V);
3817d8d0e25Snbeams       dim = 1;
3827d8d0e25Snbeams       writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
3837d8d0e25Snbeams     } else {
3847d8d0e25Snbeams       dim = 0;
3857d8d0e25Snbeams       readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
386e9132427Snbeams       ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, s_B, r_t);
387e9132427Snbeams       ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, s_G, r_V);
3887d8d0e25Snbeams       dim = 1;
3897d8d0e25Snbeams       readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
390e9132427Snbeams       ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, s_G, r_t);
391e9132427Snbeams       ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, s_B, r_U);
3927d8d0e25Snbeams       r_V += r_U;
3937d8d0e25Snbeams       writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
3947d8d0e25Snbeams     }
3957d8d0e25Snbeams   }
3967d8d0e25Snbeams }
3977d8d0e25Snbeams 
3987d8d0e25Snbeams //------------------------------------------------------------------------------
3997d8d0e25Snbeams // 2D quadrature weights
4007d8d0e25Snbeams //------------------------------------------------------------------------------
4017d8d0e25Snbeams __device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
4027d8d0e25Snbeams                          CeedScalar *w) {
4037d8d0e25Snbeams   const int i = threadIdx.x;
4047d8d0e25Snbeams   const int j = threadIdx.y;
4057d8d0e25Snbeams   const CeedScalar weight = qweight1d[i]*qweight1d[j];
4067d8d0e25Snbeams   for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
4077d8d0e25Snbeams        elem += gridDim.x*blockDim.z) {
4087d8d0e25Snbeams     const int ind = elem*Q1D*Q1D + i + j*Q1D;
4097d8d0e25Snbeams     w[ind] = weight;
4107d8d0e25Snbeams   }
4117d8d0e25Snbeams }
4127d8d0e25Snbeams 
4137d8d0e25Snbeams //------------------------------------------------------------------------------
4147d8d0e25Snbeams // 3D
4157d8d0e25Snbeams //------------------------------------------------------------------------------
4167d8d0e25Snbeams 
4177d8d0e25Snbeams //------------------------------------------------------------------------------
4187d8d0e25Snbeams // Read DoFs
4197d8d0e25Snbeams //------------------------------------------------------------------------------
4207d8d0e25Snbeams inline __device__ void readDofs3d(const int elem, const int tidx,
4217d8d0e25Snbeams                                   const int tidy, const int comp,
4227d8d0e25Snbeams                                   const int nelem, const CeedScalar *d_U,
4237d8d0e25Snbeams                                   CeedScalar *r_U) {
4247d8d0e25Snbeams   for (int i = 0; i < P1D; i++)
4257d8d0e25Snbeams     r_U[i] = (tidx < P1D && tidy < P1D) ?
4267d8d0e25Snbeams               d_U[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D +
4277d8d0e25Snbeams                   comp*P1D*P1D*P1D*nelem] : 0.0;
4287d8d0e25Snbeams   for (int i = P1D; i < Q1D; i++)
4297d8d0e25Snbeams     r_U[i] = 0.0;
4307d8d0e25Snbeams }
4317d8d0e25Snbeams 
4327d8d0e25Snbeams //------------------------------------------------------------------------------
4337d8d0e25Snbeams // Write DoFs
4347d8d0e25Snbeams //------------------------------------------------------------------------------
4357d8d0e25Snbeams inline __device__ void writeDofs3d(const int elem, const int tidx,
4367d8d0e25Snbeams                                    const int tidy, const int comp,
4377d8d0e25Snbeams                                    const int nelem, const CeedScalar *r_V,
4387d8d0e25Snbeams                                    CeedScalar *d_V) {
4397d8d0e25Snbeams   if (tidx < P1D && tidy < P1D) {
4407d8d0e25Snbeams     for (int i = 0; i < P1D; i++)
4417d8d0e25Snbeams       d_V[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D +
4427d8d0e25Snbeams           comp*P1D*P1D*P1D*nelem] = r_V[i];
4437d8d0e25Snbeams   }
4447d8d0e25Snbeams }
4457d8d0e25Snbeams 
4467d8d0e25Snbeams //------------------------------------------------------------------------------
4477d8d0e25Snbeams // Read quadrature point data
4487d8d0e25Snbeams //------------------------------------------------------------------------------
4497d8d0e25Snbeams inline __device__ void readQuads3d(const int elem, const int tidx,
4507d8d0e25Snbeams                                    const int tidy, const int comp,
4517d8d0e25Snbeams                                    const int dim, const int nelem,
4527d8d0e25Snbeams                                    const CeedScalar *d_U, CeedScalar *r_U) {
4537d8d0e25Snbeams   for (int i = 0; i < Q1D; i++)
4547d8d0e25Snbeams     r_U[i] = (tidx < Q1D && tidy < Q1D) ?
4557d8d0e25Snbeams               d_U[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D +
4567d8d0e25Snbeams               comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] : 0.0;
4577d8d0e25Snbeams   for (int i = Q1D; i < P1D; i++)
4587d8d0e25Snbeams     r_U[i] = 0.0;
4597d8d0e25Snbeams }
4607d8d0e25Snbeams 
4617d8d0e25Snbeams //------------------------------------------------------------------------------
4627d8d0e25Snbeams // Write quadrature point data
4637d8d0e25Snbeams //------------------------------------------------------------------------------
4647d8d0e25Snbeams inline __device__ void writeQuads3d(const int elem, const int tidx,
4657d8d0e25Snbeams                                     const int tidy, const int comp,
4667d8d0e25Snbeams                                     const int dim, const int nelem,
4677d8d0e25Snbeams                                     const CeedScalar *r_V, CeedScalar *d_V) {
4687d8d0e25Snbeams   if (tidx < Q1D && tidy < Q1D) {
4697d8d0e25Snbeams     for (int i = 0; i < Q1D; i++)
4707d8d0e25Snbeams       d_V[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
4717d8d0e25Snbeams           dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] = r_V[i];
4727d8d0e25Snbeams   }
4737d8d0e25Snbeams }
4747d8d0e25Snbeams 
4757d8d0e25Snbeams //------------------------------------------------------------------------------
4767d8d0e25Snbeams // 3D tensor contract x
4777d8d0e25Snbeams //------------------------------------------------------------------------------
4787d8d0e25Snbeams inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
4797d8d0e25Snbeams                                    const int tidy, const int tidz,
4807d8d0e25Snbeams                                    const CeedScalar *U,
4817d8d0e25Snbeams                                    const CeedScalar *B,
4827d8d0e25Snbeams                                    CeedScalar *V) {
4837d8d0e25Snbeams   for (int k = 0; k < P1D; ++k) {
4847d8d0e25Snbeams     slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
4857d8d0e25Snbeams     __syncthreads();
4867d8d0e25Snbeams     V[k] = 0.0;
4877d8d0e25Snbeams     if (tidx < Q1D && tidy < P1D)
4887d8d0e25Snbeams       for (int i = 0; i < P1D; ++i)
4897d8d0e25Snbeams         V[k] += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
4907d8d0e25Snbeams     __syncthreads();
4917d8d0e25Snbeams   }
4927d8d0e25Snbeams }
4937d8d0e25Snbeams 
4947d8d0e25Snbeams //------------------------------------------------------------------------------
4957d8d0e25Snbeams // 3D tensor contract y
4967d8d0e25Snbeams //------------------------------------------------------------------------------
4977d8d0e25Snbeams inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
4987d8d0e25Snbeams                                    const int tidy, const int tidz,
4997d8d0e25Snbeams                                    const CeedScalar *U,
5007d8d0e25Snbeams                                    const CeedScalar *B,
5017d8d0e25Snbeams                                    CeedScalar *V) {
5027d8d0e25Snbeams   for (int k = 0; k < P1D; ++k) {
5037d8d0e25Snbeams     slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
5047d8d0e25Snbeams     __syncthreads();
5057d8d0e25Snbeams     V[k] = 0.0;
5067d8d0e25Snbeams     if (tidx < Q1D && tidy < Q1D)
5077d8d0e25Snbeams       for (int i = 0; i < P1D; ++i)
5087d8d0e25Snbeams         V[k] += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
5097d8d0e25Snbeams     __syncthreads();
5107d8d0e25Snbeams   }
5117d8d0e25Snbeams }
5127d8d0e25Snbeams 
5137d8d0e25Snbeams //------------------------------------------------------------------------------
5147d8d0e25Snbeams // 3D tensor contract z
5157d8d0e25Snbeams //------------------------------------------------------------------------------
5167d8d0e25Snbeams inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
5177d8d0e25Snbeams                                    const int tidy, const int tidz,
5187d8d0e25Snbeams                                    const CeedScalar *U,
5197d8d0e25Snbeams                                    const CeedScalar *B,
5207d8d0e25Snbeams                                    CeedScalar *V) {
5217d8d0e25Snbeams   for (int k = 0; k < Q1D; ++k) {
5227d8d0e25Snbeams     V[k] = 0.0;
5237d8d0e25Snbeams     if (tidx < Q1D && tidy < Q1D)
5247d8d0e25Snbeams       for (int i = 0; i < P1D; ++i)
5257d8d0e25Snbeams         V[k] += B[i + k*P1D] * U[i]; // Contract z direction
5267d8d0e25Snbeams   }
5277d8d0e25Snbeams   for (int k = Q1D; k < P1D; ++k)
5287d8d0e25Snbeams     V[k] = 0.0;
5297d8d0e25Snbeams }
5307d8d0e25Snbeams 
5317d8d0e25Snbeams //------------------------------------------------------------------------------
5327d8d0e25Snbeams // 3D transpose tensor contract z
5337d8d0e25Snbeams //------------------------------------------------------------------------------
5347d8d0e25Snbeams inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
5357d8d0e25Snbeams                                             const int tidy, const int tidz,
5367d8d0e25Snbeams                                             const CeedScalar *U,
5377d8d0e25Snbeams                                             const CeedScalar *B,
5387d8d0e25Snbeams                                             CeedScalar *V) {
5397d8d0e25Snbeams   for (int k = 0; k < P1D; ++k) {
5407d8d0e25Snbeams     V[k] = 0.0;
5417d8d0e25Snbeams     if (tidx < Q1D && tidy < Q1D)
5427d8d0e25Snbeams       for (int i = 0; i < Q1D; ++i)
5437d8d0e25Snbeams         V[k] += B[k + i*P1D] * U[i]; // Contract z direction
5447d8d0e25Snbeams   }
5457d8d0e25Snbeams   for (int k = P1D; k < Q1D; ++k)
5467d8d0e25Snbeams     V[k] = 0.0;
5477d8d0e25Snbeams }
5487d8d0e25Snbeams 
5497d8d0e25Snbeams //------------------------------------------------------------------------------
5507d8d0e25Snbeams // 3D transpose tensor contract y
5517d8d0e25Snbeams //------------------------------------------------------------------------------
5527d8d0e25Snbeams inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
5537d8d0e25Snbeams                                             const int tidy, const int tidz,
5547d8d0e25Snbeams                                             const CeedScalar *U,
5557d8d0e25Snbeams                                             const CeedScalar *B,
5567d8d0e25Snbeams                                             CeedScalar *V) {
5577d8d0e25Snbeams   for (int k = 0; k < P1D; ++k) {
5587d8d0e25Snbeams     slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
5597d8d0e25Snbeams     __syncthreads();
5607d8d0e25Snbeams     V[k] = 0.0;
5617d8d0e25Snbeams     if (tidx < Q1D && tidy < P1D)
5627d8d0e25Snbeams       for (int i = 0; i < Q1D; ++i)
5637d8d0e25Snbeams         V[k] += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
5647d8d0e25Snbeams     __syncthreads();
5657d8d0e25Snbeams   }
5667d8d0e25Snbeams }
5677d8d0e25Snbeams 
5687d8d0e25Snbeams //------------------------------------------------------------------------------
5697d8d0e25Snbeams // 3D transpose tensor contract x
5707d8d0e25Snbeams //------------------------------------------------------------------------------
5717d8d0e25Snbeams inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
5727d8d0e25Snbeams                                             const int tidy, const int tidz,
5737d8d0e25Snbeams                                             const CeedScalar *U,
5747d8d0e25Snbeams                                             const CeedScalar *B,
5757d8d0e25Snbeams                                             CeedScalar *V) {
5767d8d0e25Snbeams   for (int k = 0; k < P1D; ++k) {
5777d8d0e25Snbeams     slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
5787d8d0e25Snbeams     __syncthreads();
5797d8d0e25Snbeams     V[k] = 0.0;
5807d8d0e25Snbeams     if (tidx < P1D && tidy < P1D)
5817d8d0e25Snbeams       for (int i = 0; i < Q1D; ++i)
5827d8d0e25Snbeams         V[k] += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
5837d8d0e25Snbeams     __syncthreads();
5847d8d0e25Snbeams   }
5857d8d0e25Snbeams }
5867d8d0e25Snbeams 
5877d8d0e25Snbeams //------------------------------------------------------------------------------
5887d8d0e25Snbeams // 3D interpolate to quadrature points
5897d8d0e25Snbeams //------------------------------------------------------------------------------
5907d8d0e25Snbeams inline __device__ void interp3d(const CeedInt nelem, const int transpose,
591e9132427Snbeams                                 const CeedScalar *s_B,
5927d8d0e25Snbeams                                 const CeedScalar *__restrict__ d_U,
5937d8d0e25Snbeams                                 CeedScalar *__restrict__ d_V,
5947d8d0e25Snbeams                                 CeedScalar *slice) {
5957d8d0e25Snbeams   CeedScalar r_V[T1D];
5967d8d0e25Snbeams   CeedScalar r_t[T1D];
5977d8d0e25Snbeams 
5987d8d0e25Snbeams   const int tidx = threadIdx.x;
5997d8d0e25Snbeams   const int tidy = threadIdx.y;
6007d8d0e25Snbeams   const int tidz = threadIdx.z;
6017d8d0e25Snbeams   const int blockElem = tidz/BASIS_NCOMP;
6027d8d0e25Snbeams   const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
6037d8d0e25Snbeams   const int comp = tidz%BASIS_NCOMP;
6047d8d0e25Snbeams 
6057d8d0e25Snbeams   for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
6067d8d0e25Snbeams        elem += gridDim.x*elemsPerBlock) {
6077d8d0e25Snbeams     for (int i = 0; i < T1D; ++i) {
6087d8d0e25Snbeams       r_V[i] = 0.0;
6097d8d0e25Snbeams       r_t[i] = 0.0;
6107d8d0e25Snbeams     }
6117d8d0e25Snbeams     if (!transpose) {
6127d8d0e25Snbeams       readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
613e9132427Snbeams       ContractX3d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
614e9132427Snbeams       ContractY3d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
615e9132427Snbeams       ContractZ3d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
6167d8d0e25Snbeams       writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
6177d8d0e25Snbeams     } else {
6187d8d0e25Snbeams       readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
619e9132427Snbeams       ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
620e9132427Snbeams       ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
621e9132427Snbeams       ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
6227d8d0e25Snbeams       writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
6237d8d0e25Snbeams     }
6247d8d0e25Snbeams   }
6257d8d0e25Snbeams }
6267d8d0e25Snbeams 
6277d8d0e25Snbeams //------------------------------------------------------------------------------
6287d8d0e25Snbeams // 3D derivatives at quadrature points
6297d8d0e25Snbeams //------------------------------------------------------------------------------
6307d8d0e25Snbeams inline __device__ void grad3d(const CeedInt nelem, const int transpose,
631e9132427Snbeams                               const CeedScalar *s_B, const CeedScalar *s_G,
6327d8d0e25Snbeams                               const CeedScalar *__restrict__ d_U,
6337d8d0e25Snbeams                               CeedScalar *__restrict__ d_V,
6347d8d0e25Snbeams                               CeedScalar *slice) {
6357d8d0e25Snbeams   // Use P1D for one of these
6367d8d0e25Snbeams   CeedScalar r_U[T1D];
6377d8d0e25Snbeams   CeedScalar r_V[T1D];
6387d8d0e25Snbeams   CeedScalar r_t[T1D];
6397d8d0e25Snbeams 
6407d8d0e25Snbeams   const int tidx = threadIdx.x;
6417d8d0e25Snbeams   const int tidy = threadIdx.y;
6427d8d0e25Snbeams   const int tidz = threadIdx.z;
6437d8d0e25Snbeams   const int blockElem = tidz/BASIS_NCOMP;
6447d8d0e25Snbeams   const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
6457d8d0e25Snbeams   const int comp = tidz%BASIS_NCOMP;
6467d8d0e25Snbeams   int dim;
6477d8d0e25Snbeams 
6487d8d0e25Snbeams   for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
6497d8d0e25Snbeams        elem += gridDim.x*elemsPerBlock) {
6507d8d0e25Snbeams     for (int i = 0; i < T1D; ++i) {
6517d8d0e25Snbeams       r_U[i] = 0.0;
6527d8d0e25Snbeams       r_V[i] = 0.0;
6537d8d0e25Snbeams       r_t[i] = 0.0;
6547d8d0e25Snbeams     }
6557d8d0e25Snbeams     if (!transpose) {
6567d8d0e25Snbeams       readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
657e9132427Snbeams       ContractX3d(slice, tidx, tidy, tidz, r_U, s_G, r_V);
658e9132427Snbeams       ContractY3d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
659e9132427Snbeams       ContractZ3d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
6607d8d0e25Snbeams       dim = 0;
6617d8d0e25Snbeams       writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
662e9132427Snbeams       ContractX3d(slice, tidx, tidy, tidz, r_U, s_B, r_V);
663e9132427Snbeams       ContractY3d(slice, tidx, tidy, tidz, r_V, s_G, r_t);
664e9132427Snbeams       ContractZ3d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
6657d8d0e25Snbeams       dim = 1;
6667d8d0e25Snbeams       writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
667e9132427Snbeams       ContractX3d(slice, tidx, tidy, tidz, r_U, s_B, r_V);
668e9132427Snbeams       ContractY3d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
669e9132427Snbeams       ContractZ3d(slice, tidx, tidy, tidz, r_t, s_G, r_V);
6707d8d0e25Snbeams       dim = 2;
6717d8d0e25Snbeams       writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
6727d8d0e25Snbeams     } else {
6737d8d0e25Snbeams       dim = 0;
6747d8d0e25Snbeams       readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
675e9132427Snbeams       ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, s_B, r_t);
676e9132427Snbeams       ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, s_B, r_U);
677e9132427Snbeams       ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, s_G, r_V);
6787d8d0e25Snbeams       dim = 1;
6797d8d0e25Snbeams       readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
680e9132427Snbeams       ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, s_B, r_t);
681e9132427Snbeams       ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, s_G, r_U);
682e9132427Snbeams       ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, s_B, r_t);
6837d8d0e25Snbeams       add(r_V, r_t);
6847d8d0e25Snbeams       dim = 2;
6857d8d0e25Snbeams       readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
686e9132427Snbeams       ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, s_G, r_t);
687e9132427Snbeams       ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, s_B, r_U);
688e9132427Snbeams       ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, s_B, r_t);
6897d8d0e25Snbeams       add(r_V, r_t);
6907d8d0e25Snbeams       writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
6917d8d0e25Snbeams     }
6927d8d0e25Snbeams   }
6937d8d0e25Snbeams }
6947d8d0e25Snbeams 
6957d8d0e25Snbeams //------------------------------------------------------------------------------
6967d8d0e25Snbeams // 3D quadrature weights
6977d8d0e25Snbeams //------------------------------------------------------------------------------
6987d8d0e25Snbeams __device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
6997d8d0e25Snbeams                          CeedScalar *w) {
7007d8d0e25Snbeams   const int i = threadIdx.x;
7017d8d0e25Snbeams   const int j = threadIdx.y;
7027d8d0e25Snbeams   const int k = threadIdx.z;
7037d8d0e25Snbeams   const CeedScalar weight = qweight1d[i]*qweight1d[j]*qweight1d[k];
7047d8d0e25Snbeams   for (int e = blockIdx.x; e < nelem; e += gridDim.x) {
7057d8d0e25Snbeams     const int ind = e*Q1D*Q1D*Q1D + i + j*Q1D + k*Q1D*Q1D;
7067d8d0e25Snbeams     w[ind] = weight;
7077d8d0e25Snbeams   }
7087d8d0e25Snbeams }
7097d8d0e25Snbeams 
7107d8d0e25Snbeams //------------------------------------------------------------------------------
7117d8d0e25Snbeams // Basis kernels
7127d8d0e25Snbeams //------------------------------------------------------------------------------
7137d8d0e25Snbeams 
7147d8d0e25Snbeams //------------------------------------------------------------------------------
7157d8d0e25Snbeams // Interp kernel by dim
7167d8d0e25Snbeams //------------------------------------------------------------------------------
7179e31c45bSnbeams extern "C" __launch_bounds__(INTERP_BLKSIZE) __global__ void interp(
7189e31c45bSnbeams                                   const CeedInt nelem, const int transpose,
7199dd88646Snbeams                                   CeedScalar *d_interp1d,
7207d8d0e25Snbeams                                   const CeedScalar *__restrict__ d_U,
7217d8d0e25Snbeams                                   CeedScalar *__restrict__ d_V) {
7229dd88646Snbeams 
7237d8d0e25Snbeams   HIP_DYNAMIC_SHARED( double, slice)
7249dd88646Snbeams   // load interp1d into shared memory
725e9132427Snbeams   __shared__ double s_B[P1D*Q1D];
726e9132427Snbeams   loadMatrix(d_interp1d, s_B);
72755250509Snbeams   __syncthreads();
7289dd88646Snbeams 
7297d8d0e25Snbeams   if (BASIS_DIM == 1) {
730e9132427Snbeams     interp1d(nelem, transpose, s_B, d_U, d_V, slice);
7317d8d0e25Snbeams   } else if (BASIS_DIM == 2) {
732e9132427Snbeams     interp2d(nelem, transpose, s_B, d_U, d_V, slice);
7337d8d0e25Snbeams   } else if (BASIS_DIM == 3) {
734e9132427Snbeams     interp3d(nelem, transpose, s_B, d_U, d_V, slice);
7357d8d0e25Snbeams   }
7367d8d0e25Snbeams }
7377d8d0e25Snbeams 
7387d8d0e25Snbeams //------------------------------------------------------------------------------
7397d8d0e25Snbeams // Grad kernel by dim
7407d8d0e25Snbeams //------------------------------------------------------------------------------
7419e31c45bSnbeams extern "C" __launch_bounds__(GRAD_BLKSIZE) __global__ void grad(const CeedInt nelem,
7429e31c45bSnbeams                                 const int transpose,
7439dd88646Snbeams                                 CeedScalar *d_interp1d, CeedScalar *d_grad1d,
7447d8d0e25Snbeams                                 const CeedScalar *__restrict__ d_U,
7457d8d0e25Snbeams                                 CeedScalar *__restrict__ d_V) {
7467d8d0e25Snbeams   HIP_DYNAMIC_SHARED( double, slice)
7479dd88646Snbeams   // load interp1d and grad1d into shared memory
748e9132427Snbeams   __shared__ double s_B[P1D*Q1D];
749e9132427Snbeams   loadMatrix(d_interp1d, s_B);
750e9132427Snbeams   __shared__ double s_G[P1D*Q1D];
751e9132427Snbeams   loadMatrix(d_grad1d, s_G);
75255250509Snbeams   __syncthreads();
7539dd88646Snbeams 
7547d8d0e25Snbeams   if (BASIS_DIM == 1) {
755e9132427Snbeams     grad1d(nelem, transpose, s_B, s_G, d_U, d_V, slice);
7567d8d0e25Snbeams   } else if (BASIS_DIM == 2) {
757e9132427Snbeams     grad2d(nelem, transpose, s_B, s_G, d_U, d_V, slice);
7587d8d0e25Snbeams   } else if (BASIS_DIM == 3) {
759e9132427Snbeams     grad3d(nelem, transpose, s_B, s_G, d_U, d_V, slice);
7607d8d0e25Snbeams   }
7617d8d0e25Snbeams }
7627d8d0e25Snbeams 
7637d8d0e25Snbeams //------------------------------------------------------------------------------
7647d8d0e25Snbeams // Weight kernels by dim
7657d8d0e25Snbeams //------------------------------------------------------------------------------
7669e31c45bSnbeams extern "C" __launch_bounds__(WEIGHT_BLKSIZE) __global__ void weight(const CeedInt nelem,
7677d8d0e25Snbeams                                   const CeedScalar *__restrict__ qweight1d,
7687d8d0e25Snbeams                                   CeedScalar *__restrict__ v) {
7697d8d0e25Snbeams   if (BASIS_DIM == 1) {
7707d8d0e25Snbeams     weight1d(nelem, qweight1d, v);
7717d8d0e25Snbeams   } else if (BASIS_DIM == 2) {
7727d8d0e25Snbeams     weight2d(nelem, qweight1d, v);
7737d8d0e25Snbeams   } else if (BASIS_DIM == 3) {
7747d8d0e25Snbeams     weight3d(nelem, qweight1d, v);
7757d8d0e25Snbeams   }
7767d8d0e25Snbeams }
7777d8d0e25Snbeams 
7787d8d0e25Snbeams );
7797d8d0e25Snbeams // *INDENT-ON*
7807d8d0e25Snbeams 
7817d8d0e25Snbeams //------------------------------------------------------------------------------
7829e31c45bSnbeams // Compute a block size based on required minimum threads
7839e31c45bSnbeams //------------------------------------------------------------------------------
7849e31c45bSnbeams static CeedInt ComputeBlockSizeFromRequirement(const CeedInt required) {
7859e31c45bSnbeams   CeedInt maxSize = 1024;    // Max total threads per block
7869e31c45bSnbeams   CeedInt currentSize = 64;  // Start with one group
7879e31c45bSnbeams 
7889e31c45bSnbeams   while(currentSize < maxSize) {
7899e31c45bSnbeams     if (currentSize > required)
7909e31c45bSnbeams       break;
7919e31c45bSnbeams     else
7929e31c45bSnbeams       currentSize = currentSize * 2;
7939e31c45bSnbeams   }
7949e31c45bSnbeams   return currentSize;
7959e31c45bSnbeams }
7969e31c45bSnbeams 
7979e31c45bSnbeams //------------------------------------------------------------------------------
7989e31c45bSnbeams // Compute required thread block sizes for basis kernels given P, Q, dim, and
7999e31c45bSnbeams // ncomp
8009e31c45bSnbeams //------------------------------------------------------------------------------
8019e31c45bSnbeams static int ComputeBasisThreadBlockSizes(const CeedInt dim, const CeedInt P1d,
8029e31c45bSnbeams                                         const CeedInt Q1d,
8039e31c45bSnbeams                                         const CeedInt ncomp, CeedInt *blksizes) {
8049e31c45bSnbeams 
8059e31c45bSnbeams   // Note that this will use the same block sizes for all dimensions when compiling,
8069e31c45bSnbeams   // but as each basis object is defined for a particular dimension, we will never
8079e31c45bSnbeams   // call any kernels except the ones for the dimension for which we have computed the
8089e31c45bSnbeams   // block sizes.
8099e31c45bSnbeams   const CeedInt thread1d = CeedIntMax(P1d, Q1d);
8109e31c45bSnbeams   switch (dim) {
8119e31c45bSnbeams   case 1: {
8129e31c45bSnbeams     // Interp kernels:
8139e31c45bSnbeams     blksizes[0] = 256;
8149e31c45bSnbeams 
8159e31c45bSnbeams     // Grad kernels:
8169e31c45bSnbeams     blksizes[1] = 256;
8179e31c45bSnbeams 
8189e31c45bSnbeams     // Weight kernels:
8199e31c45bSnbeams     blksizes[2] = 256;
8209e31c45bSnbeams 
8219e31c45bSnbeams   } break;
8229e31c45bSnbeams   case 2: {
8239e31c45bSnbeams     // Interp kernels:
8249e31c45bSnbeams     CeedInt required = thread1d * thread1d * ncomp;
8259e31c45bSnbeams     blksizes[0]  = ComputeBlockSizeFromRequirement(required);
8269e31c45bSnbeams 
8279e31c45bSnbeams     // Grad kernels: currently use same required minimum threads
8289e31c45bSnbeams     blksizes[1]  = ComputeBlockSizeFromRequirement(required);
8299e31c45bSnbeams 
8309e31c45bSnbeams     // Weight kernels:
8319e31c45bSnbeams     required = CeedIntMax(64, Q1d * Q1d);
8329e31c45bSnbeams     blksizes[2]  = ComputeBlockSizeFromRequirement(required);
8339e31c45bSnbeams 
8349e31c45bSnbeams   } break;
8359e31c45bSnbeams   case 3: {
8369e31c45bSnbeams     // Interp kernels:
8379e31c45bSnbeams     CeedInt required = thread1d * thread1d * ncomp;
8389e31c45bSnbeams     blksizes[0]  = ComputeBlockSizeFromRequirement(required);
8399e31c45bSnbeams 
8409e31c45bSnbeams     // Grad kernels: currently use same required minimum threads
8419e31c45bSnbeams     blksizes[1]  = ComputeBlockSizeFromRequirement(required);
8429e31c45bSnbeams 
8439e31c45bSnbeams     // Weight kernels:
8449e31c45bSnbeams     required = Q1d * Q1d * Q1d;
8459e31c45bSnbeams     blksizes[2]  = ComputeBlockSizeFromRequirement(required);
8469e31c45bSnbeams   }
8479e31c45bSnbeams   }
8489e31c45bSnbeams 
849e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
8509e31c45bSnbeams }
8519e31c45bSnbeams 
8529e31c45bSnbeams //------------------------------------------------------------------------------
8537d8d0e25Snbeams // Apply basis
8547d8d0e25Snbeams //------------------------------------------------------------------------------
8557d8d0e25Snbeams int CeedBasisApplyTensor_Hip_shared(CeedBasis basis, const CeedInt nelem,
8567d8d0e25Snbeams                                     CeedTransposeMode tmode,
8577d8d0e25Snbeams                                     CeedEvalMode emode, CeedVector u,
8587d8d0e25Snbeams                                     CeedVector v) {
8597d8d0e25Snbeams   int ierr;
8607d8d0e25Snbeams   Ceed ceed;
861e15f9bd0SJeremy L Thompson   ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
862*6dbfb411Snbeams   Ceed_Hip *ceed_Hip;
863e15f9bd0SJeremy L Thompson   CeedGetData(ceed, &ceed_Hip); CeedChkBackend(ierr);
8647d8d0e25Snbeams   CeedBasis_Hip_shared *data;
865e15f9bd0SJeremy L Thompson   CeedBasisGetData(basis, &data); CeedChkBackend(ierr);
8667d8d0e25Snbeams   const CeedInt transpose = tmode == CEED_TRANSPOSE;
8677d8d0e25Snbeams   CeedInt dim, ncomp;
868e15f9bd0SJeremy L Thompson   ierr = CeedBasisGetDimension(basis, &dim); CeedChkBackend(ierr);
869e15f9bd0SJeremy L Thompson   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChkBackend(ierr);
8707d8d0e25Snbeams 
8717d8d0e25Snbeams   // Read vectors
8727d8d0e25Snbeams   const CeedScalar *d_u;
8737d8d0e25Snbeams   CeedScalar *d_v;
8747d8d0e25Snbeams   if (emode != CEED_EVAL_WEIGHT) {
875e15f9bd0SJeremy L Thompson     ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChkBackend(ierr);
8767d8d0e25Snbeams   }
877e15f9bd0SJeremy L Thompson   ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChkBackend(ierr);
8787d8d0e25Snbeams 
8797d8d0e25Snbeams   // Clear v for transpose mode
8807d8d0e25Snbeams   if (tmode == CEED_TRANSPOSE) {
8817d8d0e25Snbeams     CeedInt length;
882e15f9bd0SJeremy L Thompson     ierr = CeedVectorGetLength(v, &length); CeedChkBackend(ierr);
883e15f9bd0SJeremy L Thompson     ierr = hipMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChkBackend(ierr);
8847d8d0e25Snbeams   }
8857d8d0e25Snbeams 
8867d8d0e25Snbeams   // Apply basis operation
8877d8d0e25Snbeams   switch (emode) {
8887d8d0e25Snbeams   case CEED_EVAL_INTERP: {
8897d8d0e25Snbeams     CeedInt P1d, Q1d;
8909e31c45bSnbeams     CeedInt blksize = data->blksizes[0];
891e15f9bd0SJeremy L Thompson     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChkBackend(ierr);
892e15f9bd0SJeremy L Thompson     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr);
8937d8d0e25Snbeams     CeedInt thread1d = CeedIntMax(Q1d, P1d);
8949dd88646Snbeams     void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->d_interp1d,
8957d8d0e25Snbeams                           &d_u, &d_v
8967d8d0e25Snbeams                          };
8977d8d0e25Snbeams     if (dim == 1) {
898e7ea6884Snbeams       CeedInt elemsPerBlock = 64*thread1d > 256? 256/thread1d : 64;
8997d8d0e25Snbeams       elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1;
9007d8d0e25Snbeams       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
9017d8d0e25Snbeams                                              ? 1 : 0 );
9027d8d0e25Snbeams       CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
9037d8d0e25Snbeams       ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, 1,
9047d8d0e25Snbeams                                        elemsPerBlock, sharedMem,
905e15f9bd0SJeremy L Thompson                                        interpargs); CeedChkBackend(ierr);
9067d8d0e25Snbeams     } else if (dim == 2) {
9079e31c45bSnbeams       // Check if required threads is small enough to do multiple elems
9089e31c45bSnbeams       const CeedInt elemsPerBlock = CeedIntMax(blksize/(thread1d*thread1d*ncomp), 1);
9097d8d0e25Snbeams       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
9107d8d0e25Snbeams                                              ? 1 : 0 );
9117d8d0e25Snbeams       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
9127d8d0e25Snbeams       ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d,
9137d8d0e25Snbeams                                        ncomp*elemsPerBlock, sharedMem,
914e15f9bd0SJeremy L Thompson                                        interpargs); CeedChkBackend(ierr);
9157d8d0e25Snbeams     } else if (dim == 3) {
9167d8d0e25Snbeams       CeedInt elemsPerBlock = 1;
9177d8d0e25Snbeams       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
9187d8d0e25Snbeams                                              ? 1 : 0 );
9197d8d0e25Snbeams       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
9207d8d0e25Snbeams       ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d,
9217d8d0e25Snbeams                                        ncomp*elemsPerBlock, sharedMem,
922e15f9bd0SJeremy L Thompson                                        interpargs); CeedChkBackend(ierr);
9237d8d0e25Snbeams     }
9247d8d0e25Snbeams   } break;
9257d8d0e25Snbeams   case CEED_EVAL_GRAD: {
9267d8d0e25Snbeams     CeedInt P1d, Q1d;
9279e31c45bSnbeams     CeedInt blksize = data->blksizes[1];
928e15f9bd0SJeremy L Thompson     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChkBackend(ierr);
929e15f9bd0SJeremy L Thompson     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr);
9307d8d0e25Snbeams     CeedInt thread1d = CeedIntMax(Q1d, P1d);
9319dd88646Snbeams     void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->d_interp1d,
9329dd88646Snbeams                         &data->d_grad1d, &d_u, &d_v
9337d8d0e25Snbeams                        };
9347d8d0e25Snbeams     if (dim == 1) {
935e7ea6884Snbeams       CeedInt elemsPerBlock = 64*thread1d > 256? 256/thread1d : 64;
9367d8d0e25Snbeams       elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1;
9377d8d0e25Snbeams       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
9387d8d0e25Snbeams                                              ? 1 : 0 );
9397d8d0e25Snbeams       CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
9407d8d0e25Snbeams       ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, 1,
9417d8d0e25Snbeams                                        elemsPerBlock, sharedMem, gradargs);
942e15f9bd0SJeremy L Thompson       CeedChkBackend(ierr);
9437d8d0e25Snbeams     } else if (dim == 2) {
9449e31c45bSnbeams       // Check if required threads is small enough to do multiple elems
9459e31c45bSnbeams       const CeedInt elemsPerBlock = CeedIntMax(blksize/(thread1d*thread1d*ncomp), 1);
9467d8d0e25Snbeams       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
9477d8d0e25Snbeams                                              ? 1 : 0 );
9487d8d0e25Snbeams       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
9497d8d0e25Snbeams       ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d,
9507d8d0e25Snbeams                                        ncomp*elemsPerBlock, sharedMem,
951e15f9bd0SJeremy L Thompson                                        gradargs); CeedChkBackend(ierr);
9527d8d0e25Snbeams     } else if (dim == 3) {
9537d8d0e25Snbeams       CeedInt elemsPerBlock = 1;
9547d8d0e25Snbeams       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
9557d8d0e25Snbeams                                              ? 1 : 0 );
9567d8d0e25Snbeams       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
9577d8d0e25Snbeams       ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d,
9587d8d0e25Snbeams                                        ncomp*elemsPerBlock, sharedMem,
959e15f9bd0SJeremy L Thompson                                        gradargs); CeedChkBackend(ierr);
9607d8d0e25Snbeams     }
9617d8d0e25Snbeams   } break;
9627d8d0e25Snbeams   case CEED_EVAL_WEIGHT: {
9637d8d0e25Snbeams     CeedInt Q1d;
9649e31c45bSnbeams     CeedInt blksize = data->blksizes[2];
965e15f9bd0SJeremy L Thompson     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr);
9667d8d0e25Snbeams     void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
9677d8d0e25Snbeams     if (dim == 1) {
9689e31c45bSnbeams       const CeedInt optElems = blksize/Q1d;
9697d8d0e25Snbeams       const CeedInt elemsPerBlock = optElems>0?optElems:1;
9707d8d0e25Snbeams       const CeedInt gridsize = nelem/elemsPerBlock + ( (
9717d8d0e25Snbeams                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
9727d8d0e25Snbeams       ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d,
9737d8d0e25Snbeams                                  elemsPerBlock, 1, weightargs);
974e15f9bd0SJeremy L Thompson       CeedChkBackend(ierr);
9757d8d0e25Snbeams     } else if (dim == 2) {
9769e31c45bSnbeams       const CeedInt optElems = blksize/(Q1d*Q1d);
9777d8d0e25Snbeams       const CeedInt elemsPerBlock = optElems>0?optElems:1;
9787d8d0e25Snbeams       const CeedInt gridsize = nelem/elemsPerBlock + ( (
9797d8d0e25Snbeams                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
9807d8d0e25Snbeams       ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d,
9817d8d0e25Snbeams                                  elemsPerBlock, weightargs);
982e15f9bd0SJeremy L Thompson       CeedChkBackend(ierr);
9837d8d0e25Snbeams     } else if (dim == 3) {
9847d8d0e25Snbeams       const CeedInt gridsize = nelem;
9857d8d0e25Snbeams       ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d, Q1d,
9867d8d0e25Snbeams                                  weightargs);
987e15f9bd0SJeremy L Thompson       CeedChkBackend(ierr);
9887d8d0e25Snbeams     }
9897d8d0e25Snbeams   } break;
9907d8d0e25Snbeams   // LCOV_EXCL_START
9917d8d0e25Snbeams   // Evaluate the divergence to/from the quadrature points
9927d8d0e25Snbeams   case CEED_EVAL_DIV:
993e15f9bd0SJeremy L Thompson     return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported");
9947d8d0e25Snbeams   // Evaluate the curl to/from the quadrature points
9957d8d0e25Snbeams   case CEED_EVAL_CURL:
996e15f9bd0SJeremy L Thompson     return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported");
9977d8d0e25Snbeams   // Take no action, BasisApply should not have been called
9987d8d0e25Snbeams   case CEED_EVAL_NONE:
999e15f9bd0SJeremy L Thompson     return CeedError(ceed, CEED_ERROR_BACKEND,
10007d8d0e25Snbeams                      "CEED_EVAL_NONE does not make sense in this context");
10017d8d0e25Snbeams     // LCOV_EXCL_STOP
10027d8d0e25Snbeams   }
10037d8d0e25Snbeams 
10047d8d0e25Snbeams   // Restore vectors
10057d8d0e25Snbeams   if (emode != CEED_EVAL_WEIGHT) {
1006e15f9bd0SJeremy L Thompson     ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChkBackend(ierr);
10077d8d0e25Snbeams   }
1008e15f9bd0SJeremy L Thompson   ierr = CeedVectorRestoreArray(v, &d_v); CeedChkBackend(ierr);
1009e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
10107d8d0e25Snbeams }
10117d8d0e25Snbeams 
10127d8d0e25Snbeams //------------------------------------------------------------------------------
10137d8d0e25Snbeams // Destroy basis
10147d8d0e25Snbeams //------------------------------------------------------------------------------
10157d8d0e25Snbeams static int CeedBasisDestroy_Hip_shared(CeedBasis basis) {
10167d8d0e25Snbeams   int ierr;
10177d8d0e25Snbeams   Ceed ceed;
1018e15f9bd0SJeremy L Thompson   ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
10197d8d0e25Snbeams 
10207d8d0e25Snbeams   CeedBasis_Hip_shared *data;
1021e15f9bd0SJeremy L Thompson   ierr = CeedBasisGetData(basis, &data); CeedChkBackend(ierr);
10227d8d0e25Snbeams 
10237d8d0e25Snbeams   CeedChk_Hip(ceed, hipModuleUnload(data->module));
10247d8d0e25Snbeams 
10257d8d0e25Snbeams   ierr = hipFree(data->d_qweight1d); CeedChk_Hip(ceed, ierr);
10267d8d0e25Snbeams   ierr = hipFree(data->d_interp1d); CeedChk_Hip(ceed, ierr);
10277d8d0e25Snbeams   ierr = hipFree(data->d_grad1d); CeedChk_Hip(ceed, ierr);
10287d8d0e25Snbeams   ierr = hipFree(data->d_collograd1d); CeedChk_Hip(ceed, ierr);
10297d8d0e25Snbeams 
1030e15f9bd0SJeremy L Thompson   ierr = CeedFree(&data); CeedChkBackend(ierr);
10317d8d0e25Snbeams 
1032e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
10337d8d0e25Snbeams }
10347d8d0e25Snbeams 
10357d8d0e25Snbeams //------------------------------------------------------------------------------
10367d8d0e25Snbeams // Create tensor basis
10377d8d0e25Snbeams //------------------------------------------------------------------------------
10387d8d0e25Snbeams int CeedBasisCreateTensorH1_Hip_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
10397d8d0e25Snbeams                                        const CeedScalar *interp1d,
10407d8d0e25Snbeams                                        const CeedScalar *grad1d,
10417d8d0e25Snbeams                                        const CeedScalar *qref1d,
10427d8d0e25Snbeams                                        const CeedScalar *qweight1d,
10437d8d0e25Snbeams                                        CeedBasis basis) {
10447d8d0e25Snbeams   int ierr;
10457d8d0e25Snbeams   Ceed ceed;
1046e15f9bd0SJeremy L Thompson   ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
10477d8d0e25Snbeams   CeedBasis_Hip_shared *data;
1048e15f9bd0SJeremy L Thompson   ierr = CeedCalloc(1, &data); CeedChkBackend(ierr);
10497d8d0e25Snbeams 
10507d8d0e25Snbeams   // Copy basis data to GPU
10517d8d0e25Snbeams   const CeedInt qBytes = Q1d * sizeof(CeedScalar);
10527d8d0e25Snbeams   ierr = hipMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Hip(ceed, ierr);
10537d8d0e25Snbeams   ierr = hipMemcpy(data->d_qweight1d, qweight1d, qBytes,
10547d8d0e25Snbeams                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
10557d8d0e25Snbeams 
10567d8d0e25Snbeams   const CeedInt iBytes = qBytes * P1d;
10577d8d0e25Snbeams   ierr = hipMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Hip(ceed, ierr);
10587d8d0e25Snbeams   ierr = hipMemcpy(data->d_interp1d, interp1d, iBytes,
10597d8d0e25Snbeams                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
10607d8d0e25Snbeams 
10617d8d0e25Snbeams   ierr = hipMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Hip(ceed, ierr);
10627d8d0e25Snbeams   ierr = hipMemcpy(data->d_grad1d, grad1d, iBytes,
10637d8d0e25Snbeams                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
10647d8d0e25Snbeams 
10657d8d0e25Snbeams   // Compute collocated gradient and copy to GPU
10667d8d0e25Snbeams   data->d_collograd1d = NULL;
10677d8d0e25Snbeams   if (dim == 3 && Q1d >= P1d) {
10687d8d0e25Snbeams     CeedScalar *collograd1d;
1069e15f9bd0SJeremy L Thompson     ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChkBackend(ierr);
1070e15f9bd0SJeremy L Thompson     ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChkBackend(ierr);
10717d8d0e25Snbeams     ierr = hipMalloc((void **)&data->d_collograd1d, qBytes * Q1d);
10727d8d0e25Snbeams     CeedChk_Hip(ceed, ierr);
10737d8d0e25Snbeams     ierr = hipMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d,
10747d8d0e25Snbeams                      hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
1075e15f9bd0SJeremy L Thompson     ierr = CeedFree(&collograd1d); CeedChkBackend(ierr);
10767d8d0e25Snbeams   }
10777d8d0e25Snbeams 
10789e31c45bSnbeams   // Set number of threads per block for basis kernels
10797d8d0e25Snbeams   CeedInt ncomp;
1080e15f9bd0SJeremy L Thompson   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChkBackend(ierr);
10819e31c45bSnbeams   ierr = ComputeBasisThreadBlockSizes(dim, P1d, Q1d, ncomp, data->blksizes);
1082e15f9bd0SJeremy L Thompson   CeedChkBackend(ierr);
10839e31c45bSnbeams 
10849e31c45bSnbeams   // Compile basis kernels
10859e31c45bSnbeams   ierr = CeedCompileHip(ceed, kernelsShared, &data->module, 11,
10867d8d0e25Snbeams                         "Q1D", Q1d,
10877d8d0e25Snbeams                         "P1D", P1d,
10887d8d0e25Snbeams                         "T1D", CeedIntMax(Q1d, P1d),
10897d8d0e25Snbeams                         "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
10907d8d0e25Snbeams                             Q1d : P1d, dim),
10917d8d0e25Snbeams                         "BASIS_DIM", dim,
10927d8d0e25Snbeams                         "BASIS_NCOMP", ncomp,
10937d8d0e25Snbeams                         "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
10949e31c45bSnbeams                         "BASIS_NQPT", CeedIntPow(Q1d, dim),
10959e31c45bSnbeams                         "INTERP_BLKSIZE", data->blksizes[0],
10969e31c45bSnbeams                         "GRAD_BLKSIZE", data->blksizes[1],
10979e31c45bSnbeams                         "WEIGHT_BLKSIZE", data->blksizes[2]
1098e15f9bd0SJeremy L Thompson                        ); CeedChkBackend(ierr);
10997d8d0e25Snbeams   ierr = CeedGetKernelHip(ceed, data->module, "interp", &data->interp);
1100e15f9bd0SJeremy L Thompson   CeedChkBackend(ierr);
11017d8d0e25Snbeams   ierr = CeedGetKernelHip(ceed, data->module, "grad", &data->grad);
1102e15f9bd0SJeremy L Thompson   CeedChkBackend(ierr);
11037d8d0e25Snbeams   ierr = CeedGetKernelHip(ceed, data->module, "weight", &data->weight);
1104e15f9bd0SJeremy L Thompson   CeedChkBackend(ierr);
11057d8d0e25Snbeams 
1106e15f9bd0SJeremy L Thompson   ierr = CeedBasisSetData(basis, data); CeedChkBackend(ierr);
11077d8d0e25Snbeams 
11087d8d0e25Snbeams   // Register backend functions
11097d8d0e25Snbeams   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
11107d8d0e25Snbeams                                 CeedBasisApplyTensor_Hip_shared);
1111e15f9bd0SJeremy L Thompson   CeedChkBackend(ierr);
11127d8d0e25Snbeams   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
1113e15f9bd0SJeremy L Thompson                                 CeedBasisDestroy_Hip_shared); CeedChkBackend(ierr);
1114e15f9bd0SJeremy L Thompson   return CEED_ERROR_SUCCESS;
11157d8d0e25Snbeams }
11167d8d0e25Snbeams //------------------------------------------------------------------------------
1117