xref: /libCEED/rust/libceed-sys/c-src/backends/hip-shared/ceed-hip-shared-basis.c (revision 3d576824e8d990e1f48c6609089904bee9170514)
17d8d0e25Snbeams // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
27d8d0e25Snbeams // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
37d8d0e25Snbeams // All Rights reserved. See files LICENSE and NOTICE for details.
47d8d0e25Snbeams //
57d8d0e25Snbeams // This file is part of CEED, a collection of benchmarks, miniapps, software
67d8d0e25Snbeams // libraries and APIs for efficient high-order finite element and spectral
77d8d0e25Snbeams // element discretizations for exascale applications. For more information and
87d8d0e25Snbeams // source code availability see http://github.com/ceed.
97d8d0e25Snbeams //
107d8d0e25Snbeams // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
117d8d0e25Snbeams // a collaborative effort of two U.S. Department of Energy organizations (Office
127d8d0e25Snbeams // of Science and the National Nuclear Security Administration) responsible for
137d8d0e25Snbeams // the planning and preparation of a capable exascale ecosystem, including
147d8d0e25Snbeams // software, applications, hardware, advanced system engineering and early
157d8d0e25Snbeams // testbed platforms, in support of the nation's exascale computing imperative.
167d8d0e25Snbeams 
17*3d576824SJeremy L Thompson #include <ceed.h>
18*3d576824SJeremy L Thompson #include <ceed-backend.h>
19*3d576824SJeremy L Thompson #include <hip/hip_runtime.h>
20*3d576824SJeremy L Thompson #include <stddef.h>
217d8d0e25Snbeams #include "ceed-hip-shared.h"
22*3d576824SJeremy L Thompson #include "../hip/ceed-hip.h"
237d8d0e25Snbeams #include "../hip/ceed-hip-compile.h"
247d8d0e25Snbeams 
257d8d0e25Snbeams //------------------------------------------------------------------------------
267d8d0e25Snbeams // Shared mem kernels
277d8d0e25Snbeams //------------------------------------------------------------------------------
287d8d0e25Snbeams // *INDENT-OFF*
297d8d0e25Snbeams static const char *kernelsShared = QUOTE(
307d8d0e25Snbeams 
317d8d0e25Snbeams //------------------------------------------------------------------------------
327d8d0e25Snbeams // Sum input into output
337d8d0e25Snbeams //------------------------------------------------------------------------------
347d8d0e25Snbeams inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
357d8d0e25Snbeams   for (int i = 0; i < P1D; i++)
367d8d0e25Snbeams     r_V[i] += r_U[i];
377d8d0e25Snbeams }
387d8d0e25Snbeams 
397d8d0e25Snbeams //------------------------------------------------------------------------------
407d8d0e25Snbeams // 1D
417d8d0e25Snbeams //------------------------------------------------------------------------------
427d8d0e25Snbeams 
437d8d0e25Snbeams //------------------------------------------------------------------------------
447d8d0e25Snbeams // Read DoFs
457d8d0e25Snbeams //------------------------------------------------------------------------------
467d8d0e25Snbeams inline __device__ void readDofs1d(const int elem, const int tidx,
477d8d0e25Snbeams                                   const int tidy, const int tidz,const int comp,
487d8d0e25Snbeams                                   const int nelem, const CeedScalar *d_U,
497d8d0e25Snbeams                                   CeedScalar *slice) {
507d8d0e25Snbeams   for (int i = 0; i < P1D; i++)
517d8d0e25Snbeams     slice[i + tidz*T1D] = d_U[i + elem*P1D + comp*P1D*nelem];
527d8d0e25Snbeams   for (int i = P1D; i < Q1D; i++)
537d8d0e25Snbeams     slice[i + tidz*T1D] = 0.0;
547d8d0e25Snbeams }
557d8d0e25Snbeams 
567d8d0e25Snbeams //------------------------------------------------------------------------------
577d8d0e25Snbeams // Write DoFs
587d8d0e25Snbeams //------------------------------------------------------------------------------
597d8d0e25Snbeams inline __device__ void writeDofs1d(const int elem, const int tidx,
607d8d0e25Snbeams                                    const int tidy, const int comp,
617d8d0e25Snbeams                                    const int nelem, const CeedScalar &r_V,
627d8d0e25Snbeams                                    CeedScalar *d_V) {
637d8d0e25Snbeams   if (tidx<P1D)
647d8d0e25Snbeams     d_V[tidx + elem*P1D + comp*P1D*nelem] = r_V;
657d8d0e25Snbeams }
667d8d0e25Snbeams 
677d8d0e25Snbeams //------------------------------------------------------------------------------
687d8d0e25Snbeams // Read quadrature point data
697d8d0e25Snbeams //------------------------------------------------------------------------------
707d8d0e25Snbeams inline __device__ void readQuads1d(const int elem, const int tidx,
717d8d0e25Snbeams                                    const int tidy, const int tidz, const int comp,
727d8d0e25Snbeams                                    const int dim, const int nelem,
737d8d0e25Snbeams                                    const CeedScalar *d_U, CeedScalar *slice) {
747d8d0e25Snbeams   for (int i = 0; i < Q1D; i++)
757d8d0e25Snbeams     slice[i + tidz*T1D] = d_U[i + elem*Q1D + comp*Q1D*nelem +
767d8d0e25Snbeams                             dim*BASIS_NCOMP*nelem*Q1D];
777d8d0e25Snbeams   for (int i = Q1D; i < P1D; i++)
787d8d0e25Snbeams     slice[i + tidz*T1D] = 0.0;
797d8d0e25Snbeams }
807d8d0e25Snbeams 
817d8d0e25Snbeams //------------------------------------------------------------------------------
827d8d0e25Snbeams // Write quadrature point data
837d8d0e25Snbeams //------------------------------------------------------------------------------
847d8d0e25Snbeams inline __device__ void writeQuads1d(const int elem, const int tidx,
857d8d0e25Snbeams                                     const int tidy, const int comp,
867d8d0e25Snbeams                                     const int dim, const int nelem,
877d8d0e25Snbeams                                     const CeedScalar &r_V, CeedScalar *d_V) {
887d8d0e25Snbeams   if (tidx<Q1D)
897d8d0e25Snbeams     d_V[tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D] = r_V;
907d8d0e25Snbeams }
917d8d0e25Snbeams 
927d8d0e25Snbeams //------------------------------------------------------------------------------
937d8d0e25Snbeams // 1D tensor contraction
947d8d0e25Snbeams //------------------------------------------------------------------------------
957d8d0e25Snbeams inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
967d8d0e25Snbeams                                    const int tidy, const int tidz,
977d8d0e25Snbeams                                    const CeedScalar &U, const CeedScalar *B,
987d8d0e25Snbeams                                    CeedScalar &V) {
997d8d0e25Snbeams   V = 0.0;
1007d8d0e25Snbeams   for (int i = 0; i < P1D; ++i)
1017d8d0e25Snbeams     V += B[i + tidx*P1D] * slice[i + tidz*T1D]; // Contract x direction
1027d8d0e25Snbeams }
1037d8d0e25Snbeams 
1047d8d0e25Snbeams //------------------------------------------------------------------------------
1057d8d0e25Snbeams // 1D transpose tensor contraction
1067d8d0e25Snbeams //------------------------------------------------------------------------------
1077d8d0e25Snbeams inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
1087d8d0e25Snbeams     const int tidy, const int tidz,
1097d8d0e25Snbeams     const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
1107d8d0e25Snbeams   V = 0.0;
1117d8d0e25Snbeams   for (int i = 0; i < Q1D; ++i)
1127d8d0e25Snbeams     V += B[tidx + i*P1D] * slice[i + tidz*T1D]; // Contract x direction
1137d8d0e25Snbeams }
1147d8d0e25Snbeams 
1157d8d0e25Snbeams //------------------------------------------------------------------------------
1167d8d0e25Snbeams // 1D interpolate to quadrature points
1177d8d0e25Snbeams //------------------------------------------------------------------------------
1187d8d0e25Snbeams inline __device__ void interp1d(const CeedInt nelem, const int transpose,
1197d8d0e25Snbeams                                 const CeedScalar *c_B,
1207d8d0e25Snbeams                                 const CeedScalar *__restrict__ d_U,
1217d8d0e25Snbeams                                 CeedScalar *__restrict__ d_V,
1227d8d0e25Snbeams                                 CeedScalar *slice) {
1237d8d0e25Snbeams   CeedScalar r_V;
1247d8d0e25Snbeams   CeedScalar r_t;
1257d8d0e25Snbeams 
1267d8d0e25Snbeams   const int tidx = threadIdx.x;
1277d8d0e25Snbeams   const int tidy = threadIdx.y;
1287d8d0e25Snbeams   const int tidz = threadIdx.z;
1297d8d0e25Snbeams 
1307d8d0e25Snbeams 
1317d8d0e25Snbeams   for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
1327d8d0e25Snbeams        elem += gridDim.x*blockDim.z) {
1337d8d0e25Snbeams     for (int comp = 0; comp < BASIS_NCOMP; comp++) {
1347d8d0e25Snbeams       if (!transpose) {
1357d8d0e25Snbeams         readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
1367d8d0e25Snbeams         ContractX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
1377d8d0e25Snbeams         writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
1387d8d0e25Snbeams       } else {
1397d8d0e25Snbeams         readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
1407d8d0e25Snbeams         ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
1417d8d0e25Snbeams         writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
1427d8d0e25Snbeams       }
1437d8d0e25Snbeams     }
1447d8d0e25Snbeams   }
1457d8d0e25Snbeams }
1467d8d0e25Snbeams 
1477d8d0e25Snbeams //------------------------------------------------------------------------------
1487d8d0e25Snbeams // 1D derivatives at quadrature points
1497d8d0e25Snbeams //------------------------------------------------------------------------------
1507d8d0e25Snbeams inline __device__ void grad1d(const CeedInt nelem, const int transpose,
1517d8d0e25Snbeams                               const CeedScalar *c_B, const CeedScalar *c_G,
1527d8d0e25Snbeams                               const CeedScalar *__restrict__ d_U,
1537d8d0e25Snbeams                               CeedScalar *__restrict__ d_V,
1547d8d0e25Snbeams                               CeedScalar *slice) {
1557d8d0e25Snbeams   CeedScalar r_U;
1567d8d0e25Snbeams   CeedScalar r_V;
1577d8d0e25Snbeams 
1587d8d0e25Snbeams   const int tidx = threadIdx.x;
1597d8d0e25Snbeams   const int tidy = threadIdx.y;
1607d8d0e25Snbeams   const int tidz = threadIdx.z;
1617d8d0e25Snbeams   int dim;
1627d8d0e25Snbeams 
1637d8d0e25Snbeams   for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
1647d8d0e25Snbeams        elem += gridDim.x*blockDim.z) {
1657d8d0e25Snbeams     for(int comp = 0; comp < BASIS_NCOMP; comp++) {
1667d8d0e25Snbeams       if (!transpose) {
1677d8d0e25Snbeams         readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
1687d8d0e25Snbeams         ContractX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
1697d8d0e25Snbeams         dim = 0;
1707d8d0e25Snbeams         writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
1717d8d0e25Snbeams       } else {
1727d8d0e25Snbeams         dim = 0;
1737d8d0e25Snbeams         readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice);
1747d8d0e25Snbeams         ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
1757d8d0e25Snbeams         writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
1767d8d0e25Snbeams       }
1777d8d0e25Snbeams     }
1787d8d0e25Snbeams   }
1797d8d0e25Snbeams }
1807d8d0e25Snbeams 
1817d8d0e25Snbeams //------------------------------------------------------------------------------
1827d8d0e25Snbeams // 1D Quadrature weights
1837d8d0e25Snbeams //------------------------------------------------------------------------------
1847d8d0e25Snbeams __device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
1857d8d0e25Snbeams                          CeedScalar *w) {
1867d8d0e25Snbeams   const int tid = threadIdx.x;
1877d8d0e25Snbeams   const CeedScalar weight = qweight1d[tid];
1887d8d0e25Snbeams   for (CeedInt elem = blockIdx.x*blockDim.y + threadIdx.y; elem < nelem;
1897d8d0e25Snbeams        elem += gridDim.x*blockDim.y) {
1907d8d0e25Snbeams     const int ind = elem*Q1D + tid;
1917d8d0e25Snbeams     w[ind] = weight;
1927d8d0e25Snbeams   }
1937d8d0e25Snbeams }
1947d8d0e25Snbeams 
1957d8d0e25Snbeams //------------------------------------------------------------------------------
1967d8d0e25Snbeams // 2D
1977d8d0e25Snbeams //------------------------------------------------------------------------------
1987d8d0e25Snbeams 
1997d8d0e25Snbeams //------------------------------------------------------------------------------
2007d8d0e25Snbeams // Read DoFs
2017d8d0e25Snbeams //------------------------------------------------------------------------------
2027d8d0e25Snbeams inline __device__ void readDofs2d(const int elem, const int tidx,
2037d8d0e25Snbeams                                   const int tidy, const int comp,
2047d8d0e25Snbeams                                   const int nelem, const CeedScalar *d_U,
2057d8d0e25Snbeams                                   CeedScalar &U) {
2067d8d0e25Snbeams   U = (tidx<P1D && tidy<P1D) ?
2077d8d0e25Snbeams       d_U[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] : 0.0;
2087d8d0e25Snbeams }
2097d8d0e25Snbeams 
2107d8d0e25Snbeams //------------------------------------------------------------------------------
2117d8d0e25Snbeams // Write DoFs
2127d8d0e25Snbeams //------------------------------------------------------------------------------
2137d8d0e25Snbeams inline __device__ void writeDofs2d(const int elem, const int tidx,
2147d8d0e25Snbeams                                    const int tidy, const int comp,
2157d8d0e25Snbeams                                    const int nelem, const CeedScalar &r_V,
2167d8d0e25Snbeams                                    CeedScalar *d_V) {
2177d8d0e25Snbeams   if (tidx<P1D && tidy<P1D)
2187d8d0e25Snbeams     d_V[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] = r_V;
2197d8d0e25Snbeams }
2207d8d0e25Snbeams 
2217d8d0e25Snbeams //------------------------------------------------------------------------------
2227d8d0e25Snbeams // Read quadrature point data
2237d8d0e25Snbeams //------------------------------------------------------------------------------
2247d8d0e25Snbeams inline __device__ void readQuads2d(const int elem, const int tidx,
2257d8d0e25Snbeams                                    const int tidy, const int comp,
2267d8d0e25Snbeams                                    const int dim, const int nelem,
2277d8d0e25Snbeams                                    const CeedScalar *d_U, CeedScalar &U ) {
2287d8d0e25Snbeams   U = (tidx<Q1D && tidy<Q1D) ?
2297d8d0e25Snbeams       d_U[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
2307d8d0e25Snbeams       dim*BASIS_NCOMP*nelem*Q1D*Q1D] : 0.0;
2317d8d0e25Snbeams }
2327d8d0e25Snbeams 
2337d8d0e25Snbeams //------------------------------------------------------------------------------
2347d8d0e25Snbeams // Write quadrature point data
2357d8d0e25Snbeams //------------------------------------------------------------------------------
2367d8d0e25Snbeams inline __device__ void writeQuads2d(const int elem, const int tidx,
2377d8d0e25Snbeams                                     const int tidy, const int comp,
2387d8d0e25Snbeams                                     const int dim, const int nelem,
2397d8d0e25Snbeams                                     const CeedScalar &r_V, CeedScalar *d_V) {
2407d8d0e25Snbeams   if (tidx<Q1D && tidy<Q1D)
2417d8d0e25Snbeams     d_V[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
2427d8d0e25Snbeams     dim*BASIS_NCOMP*nelem*Q1D*Q1D] = r_V;
2437d8d0e25Snbeams }
2447d8d0e25Snbeams 
2457d8d0e25Snbeams //------------------------------------------------------------------------------
2467d8d0e25Snbeams // 2D tensor contraction x
2477d8d0e25Snbeams //------------------------------------------------------------------------------
2487d8d0e25Snbeams inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
2497d8d0e25Snbeams                                    const int tidy, const int tidz,
2507d8d0e25Snbeams                                    const CeedScalar &U, const CeedScalar *B,
2517d8d0e25Snbeams                                    CeedScalar &V) {
2527d8d0e25Snbeams   slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
2537d8d0e25Snbeams   __syncthreads();
2547d8d0e25Snbeams   V = 0.0;
2557d8d0e25Snbeams   if (tidx < Q1D)
2567d8d0e25Snbeams     for (int i = 0; i < P1D; ++i)
2577d8d0e25Snbeams       V += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
2587d8d0e25Snbeams   __syncthreads();
2597d8d0e25Snbeams }
2607d8d0e25Snbeams 
2617d8d0e25Snbeams //------------------------------------------------------------------------------
2627d8d0e25Snbeams // 2D tensor contraction y
2637d8d0e25Snbeams //------------------------------------------------------------------------------
2647d8d0e25Snbeams inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
2657d8d0e25Snbeams                                    const int tidy, const int tidz,
2667d8d0e25Snbeams                                    const CeedScalar &U, const CeedScalar *B,
2677d8d0e25Snbeams                                    CeedScalar &V) {
2687d8d0e25Snbeams   slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
2697d8d0e25Snbeams   __syncthreads();
2707d8d0e25Snbeams   V = 0.0;
2717d8d0e25Snbeams   if (tidy < Q1D)
2727d8d0e25Snbeams     for (int i = 0; i < P1D; ++i)
2737d8d0e25Snbeams       V += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
2747d8d0e25Snbeams   __syncthreads();
2757d8d0e25Snbeams }
2767d8d0e25Snbeams 
2777d8d0e25Snbeams //------------------------------------------------------------------------------
2787d8d0e25Snbeams // 2D transpose tensor contraction y
2797d8d0e25Snbeams //------------------------------------------------------------------------------
2807d8d0e25Snbeams inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
2817d8d0e25Snbeams     const int tidy, const int tidz,
2827d8d0e25Snbeams     const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
2837d8d0e25Snbeams   slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
2847d8d0e25Snbeams   __syncthreads();
2857d8d0e25Snbeams   V = 0.0;
2867d8d0e25Snbeams   if (tidy < P1D)
2877d8d0e25Snbeams     for (int i = 0; i < Q1D; ++i)
2887d8d0e25Snbeams       V += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
2897d8d0e25Snbeams   __syncthreads();
2907d8d0e25Snbeams }
2917d8d0e25Snbeams 
2927d8d0e25Snbeams //------------------------------------------------------------------------------
2937d8d0e25Snbeams // 2D transpose tensor contraction x
2947d8d0e25Snbeams //------------------------------------------------------------------------------
2957d8d0e25Snbeams inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
2967d8d0e25Snbeams     const int tidy, const int tidz,
2977d8d0e25Snbeams     const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
2987d8d0e25Snbeams   slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
2997d8d0e25Snbeams   __syncthreads();
3007d8d0e25Snbeams   V = 0.0;
3017d8d0e25Snbeams   if (tidx < P1D)
3027d8d0e25Snbeams     for (int i = 0; i < Q1D; ++i)
3037d8d0e25Snbeams       V += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
3047d8d0e25Snbeams   __syncthreads();
3057d8d0e25Snbeams }
3067d8d0e25Snbeams 
3077d8d0e25Snbeams //------------------------------------------------------------------------------
3087d8d0e25Snbeams // 2D interpolate to quadrature points
3097d8d0e25Snbeams //------------------------------------------------------------------------------
3107d8d0e25Snbeams inline __device__ void interp2d(const CeedInt nelem, const int transpose,
3117d8d0e25Snbeams                                 const CeedScalar *c_B,
3127d8d0e25Snbeams                                 const CeedScalar *__restrict__ d_U,
3137d8d0e25Snbeams                                 CeedScalar *__restrict__ d_V,
3147d8d0e25Snbeams                                 CeedScalar *slice) {
3157d8d0e25Snbeams   CeedScalar r_V;
3167d8d0e25Snbeams   CeedScalar r_t;
3177d8d0e25Snbeams 
3187d8d0e25Snbeams   const int tidx = threadIdx.x;
3197d8d0e25Snbeams   const int tidy = threadIdx.y;
3207d8d0e25Snbeams   const int tidz = threadIdx.z;
3217d8d0e25Snbeams   const int blockElem = tidz/BASIS_NCOMP;
3227d8d0e25Snbeams   const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
3237d8d0e25Snbeams   const int comp = tidz%BASIS_NCOMP;
3247d8d0e25Snbeams 
3257d8d0e25Snbeams   for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
3267d8d0e25Snbeams        elem += gridDim.x*elemsPerBlock) {
3277d8d0e25Snbeams     const int comp = tidz%BASIS_NCOMP;
3287d8d0e25Snbeams     r_V = 0.0;
3297d8d0e25Snbeams     r_t = 0.0;
3307d8d0e25Snbeams     if (!transpose) {
3317d8d0e25Snbeams       readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
3327d8d0e25Snbeams       ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
3337d8d0e25Snbeams       ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
3347d8d0e25Snbeams       writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
3357d8d0e25Snbeams     } else {
3367d8d0e25Snbeams       readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
3377d8d0e25Snbeams       ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
3387d8d0e25Snbeams       ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
3397d8d0e25Snbeams       writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
3407d8d0e25Snbeams     }
3417d8d0e25Snbeams   }
3427d8d0e25Snbeams }
3437d8d0e25Snbeams 
3447d8d0e25Snbeams //------------------------------------------------------------------------------
3457d8d0e25Snbeams // 2D derivatives at quadrature points
3467d8d0e25Snbeams //------------------------------------------------------------------------------
3477d8d0e25Snbeams inline __device__ void grad2d(const CeedInt nelem, const int transpose,
3487d8d0e25Snbeams                               const CeedScalar *c_B, const CeedScalar *c_G,
3497d8d0e25Snbeams                               const CeedScalar *__restrict__ d_U,
3507d8d0e25Snbeams                               CeedScalar *__restrict__ d_V, CeedScalar *slice) {
3517d8d0e25Snbeams   CeedScalar r_U;
3527d8d0e25Snbeams   CeedScalar r_V;
3537d8d0e25Snbeams   CeedScalar r_t;
3547d8d0e25Snbeams 
3557d8d0e25Snbeams   const int tidx = threadIdx.x;
3567d8d0e25Snbeams   const int tidy = threadIdx.y;
3577d8d0e25Snbeams   const int tidz = threadIdx.z;
3587d8d0e25Snbeams   const int blockElem = tidz/BASIS_NCOMP;
3597d8d0e25Snbeams   const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
3607d8d0e25Snbeams   const int comp = tidz%BASIS_NCOMP;
3617d8d0e25Snbeams   int dim;
3627d8d0e25Snbeams 
3637d8d0e25Snbeams   for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
3647d8d0e25Snbeams        elem += gridDim.x*elemsPerBlock) {
3657d8d0e25Snbeams     if (!transpose) {
3667d8d0e25Snbeams       readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
3677d8d0e25Snbeams       ContractX2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
3687d8d0e25Snbeams       ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
3697d8d0e25Snbeams       dim = 0;
3707d8d0e25Snbeams       writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
3717d8d0e25Snbeams       ContractX2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
3727d8d0e25Snbeams       ContractY2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
3737d8d0e25Snbeams       dim = 1;
3747d8d0e25Snbeams       writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
3757d8d0e25Snbeams     } else {
3767d8d0e25Snbeams       dim = 0;
3777d8d0e25Snbeams       readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
3787d8d0e25Snbeams       ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
3797d8d0e25Snbeams       ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
3807d8d0e25Snbeams       dim = 1;
3817d8d0e25Snbeams       readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
3827d8d0e25Snbeams       ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
3837d8d0e25Snbeams       ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
3847d8d0e25Snbeams       r_V += r_U;
3857d8d0e25Snbeams       writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
3867d8d0e25Snbeams     }
3877d8d0e25Snbeams   }
3887d8d0e25Snbeams }
3897d8d0e25Snbeams 
3907d8d0e25Snbeams //------------------------------------------------------------------------------
3917d8d0e25Snbeams // 2D quadrature weights
3927d8d0e25Snbeams //------------------------------------------------------------------------------
3937d8d0e25Snbeams __device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
3947d8d0e25Snbeams                          CeedScalar *w) {
3957d8d0e25Snbeams   const int i = threadIdx.x;
3967d8d0e25Snbeams   const int j = threadIdx.y;
3977d8d0e25Snbeams   const CeedScalar weight = qweight1d[i]*qweight1d[j];
3987d8d0e25Snbeams   for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
3997d8d0e25Snbeams        elem += gridDim.x*blockDim.z) {
4007d8d0e25Snbeams     const int ind = elem*Q1D*Q1D + i + j*Q1D;
4017d8d0e25Snbeams     w[ind] = weight;
4027d8d0e25Snbeams   }
4037d8d0e25Snbeams }
4047d8d0e25Snbeams 
4057d8d0e25Snbeams //------------------------------------------------------------------------------
4067d8d0e25Snbeams // 3D
4077d8d0e25Snbeams //------------------------------------------------------------------------------
4087d8d0e25Snbeams 
4097d8d0e25Snbeams //------------------------------------------------------------------------------
4107d8d0e25Snbeams // Read DoFs
4117d8d0e25Snbeams //------------------------------------------------------------------------------
4127d8d0e25Snbeams inline __device__ void readDofs3d(const int elem, const int tidx,
4137d8d0e25Snbeams                                   const int tidy, const int comp,
4147d8d0e25Snbeams                                   const int nelem, const CeedScalar *d_U,
4157d8d0e25Snbeams                                   CeedScalar *r_U) {
4167d8d0e25Snbeams   for (int i = 0; i < P1D; i++)
4177d8d0e25Snbeams     r_U[i] = (tidx < P1D && tidy < P1D) ?
4187d8d0e25Snbeams               d_U[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D +
4197d8d0e25Snbeams                   comp*P1D*P1D*P1D*nelem] : 0.0;
4207d8d0e25Snbeams   for (int i = P1D; i < Q1D; i++)
4217d8d0e25Snbeams     r_U[i] = 0.0;
4227d8d0e25Snbeams }
4237d8d0e25Snbeams 
4247d8d0e25Snbeams //------------------------------------------------------------------------------
4257d8d0e25Snbeams // Write DoFs
4267d8d0e25Snbeams //------------------------------------------------------------------------------
4277d8d0e25Snbeams inline __device__ void writeDofs3d(const int elem, const int tidx,
4287d8d0e25Snbeams                                    const int tidy, const int comp,
4297d8d0e25Snbeams                                    const int nelem, const CeedScalar *r_V,
4307d8d0e25Snbeams                                    CeedScalar *d_V) {
4317d8d0e25Snbeams   if (tidx < P1D && tidy < P1D) {
4327d8d0e25Snbeams     for (int i = 0; i < P1D; i++)
4337d8d0e25Snbeams       d_V[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D +
4347d8d0e25Snbeams           comp*P1D*P1D*P1D*nelem] = r_V[i];
4357d8d0e25Snbeams   }
4367d8d0e25Snbeams }
4377d8d0e25Snbeams 
4387d8d0e25Snbeams //------------------------------------------------------------------------------
4397d8d0e25Snbeams // Read quadrature point data
4407d8d0e25Snbeams //------------------------------------------------------------------------------
4417d8d0e25Snbeams inline __device__ void readQuads3d(const int elem, const int tidx,
4427d8d0e25Snbeams                                    const int tidy, const int comp,
4437d8d0e25Snbeams                                    const int dim, const int nelem,
4447d8d0e25Snbeams                                    const CeedScalar *d_U, CeedScalar *r_U) {
4457d8d0e25Snbeams   for (int i = 0; i < Q1D; i++)
4467d8d0e25Snbeams     r_U[i] = (tidx < Q1D && tidy < Q1D) ?
4477d8d0e25Snbeams               d_U[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D +
4487d8d0e25Snbeams               comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] : 0.0;
4497d8d0e25Snbeams   for (int i = Q1D; i < P1D; i++)
4507d8d0e25Snbeams     r_U[i] = 0.0;
4517d8d0e25Snbeams }
4527d8d0e25Snbeams 
4537d8d0e25Snbeams //------------------------------------------------------------------------------
4547d8d0e25Snbeams // Write quadrature point data
4557d8d0e25Snbeams //------------------------------------------------------------------------------
4567d8d0e25Snbeams inline __device__ void writeQuads3d(const int elem, const int tidx,
4577d8d0e25Snbeams                                     const int tidy, const int comp,
4587d8d0e25Snbeams                                     const int dim, const int nelem,
4597d8d0e25Snbeams                                     const CeedScalar *r_V, CeedScalar *d_V) {
4607d8d0e25Snbeams   if (tidx < Q1D && tidy < Q1D) {
4617d8d0e25Snbeams     for (int i = 0; i < Q1D; i++)
4627d8d0e25Snbeams       d_V[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
4637d8d0e25Snbeams           dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] = r_V[i];
4647d8d0e25Snbeams   }
4657d8d0e25Snbeams }
4667d8d0e25Snbeams 
4677d8d0e25Snbeams //------------------------------------------------------------------------------
4687d8d0e25Snbeams // 3D tensor contract x
4697d8d0e25Snbeams //------------------------------------------------------------------------------
4707d8d0e25Snbeams inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
4717d8d0e25Snbeams                                    const int tidy, const int tidz,
4727d8d0e25Snbeams                                    const CeedScalar *U,
4737d8d0e25Snbeams                                    const CeedScalar *B,
4747d8d0e25Snbeams                                    CeedScalar *V) {
4757d8d0e25Snbeams   for (int k = 0; k < P1D; ++k) {
4767d8d0e25Snbeams     slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
4777d8d0e25Snbeams     __syncthreads();
4787d8d0e25Snbeams     V[k] = 0.0;
4797d8d0e25Snbeams     if (tidx < Q1D && tidy < P1D)
4807d8d0e25Snbeams       for (int i = 0; i < P1D; ++i)
4817d8d0e25Snbeams         V[k] += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
4827d8d0e25Snbeams     __syncthreads();
4837d8d0e25Snbeams   }
4847d8d0e25Snbeams }
4857d8d0e25Snbeams 
4867d8d0e25Snbeams //------------------------------------------------------------------------------
4877d8d0e25Snbeams // 3D tensor contract y
4887d8d0e25Snbeams //------------------------------------------------------------------------------
4897d8d0e25Snbeams inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
4907d8d0e25Snbeams                                    const int tidy, const int tidz,
4917d8d0e25Snbeams                                    const CeedScalar *U,
4927d8d0e25Snbeams                                    const CeedScalar *B,
4937d8d0e25Snbeams                                    CeedScalar *V) {
4947d8d0e25Snbeams   for (int k = 0; k < P1D; ++k) {
4957d8d0e25Snbeams     slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
4967d8d0e25Snbeams     __syncthreads();
4977d8d0e25Snbeams     V[k] = 0.0;
4987d8d0e25Snbeams     if (tidx < Q1D && tidy < Q1D)
4997d8d0e25Snbeams       for (int i = 0; i < P1D; ++i)
5007d8d0e25Snbeams         V[k] += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
5017d8d0e25Snbeams     __syncthreads();
5027d8d0e25Snbeams   }
5037d8d0e25Snbeams }
5047d8d0e25Snbeams 
5057d8d0e25Snbeams //------------------------------------------------------------------------------
5067d8d0e25Snbeams // 3D tensor contract z
5077d8d0e25Snbeams //------------------------------------------------------------------------------
5087d8d0e25Snbeams inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
5097d8d0e25Snbeams                                    const int tidy, const int tidz,
5107d8d0e25Snbeams                                    const CeedScalar *U,
5117d8d0e25Snbeams                                    const CeedScalar *B,
5127d8d0e25Snbeams                                    CeedScalar *V) {
5137d8d0e25Snbeams   for (int k = 0; k < Q1D; ++k) {
5147d8d0e25Snbeams     V[k] = 0.0;
5157d8d0e25Snbeams     if (tidx < Q1D && tidy < Q1D)
5167d8d0e25Snbeams       for (int i = 0; i < P1D; ++i)
5177d8d0e25Snbeams         V[k] += B[i + k*P1D] * U[i]; // Contract z direction
5187d8d0e25Snbeams   }
5197d8d0e25Snbeams   for (int k = Q1D; k < P1D; ++k)
5207d8d0e25Snbeams     V[k] = 0.0;
5217d8d0e25Snbeams }
5227d8d0e25Snbeams 
5237d8d0e25Snbeams //------------------------------------------------------------------------------
5247d8d0e25Snbeams // 3D transpose tensor contract z
5257d8d0e25Snbeams //------------------------------------------------------------------------------
5267d8d0e25Snbeams inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
5277d8d0e25Snbeams                                             const int tidy, const int tidz,
5287d8d0e25Snbeams                                             const CeedScalar *U,
5297d8d0e25Snbeams                                             const CeedScalar *B,
5307d8d0e25Snbeams                                             CeedScalar *V) {
5317d8d0e25Snbeams   for (int k = 0; k < P1D; ++k) {
5327d8d0e25Snbeams     V[k] = 0.0;
5337d8d0e25Snbeams     if (tidx < Q1D && tidy < Q1D)
5347d8d0e25Snbeams       for (int i = 0; i < Q1D; ++i)
5357d8d0e25Snbeams         V[k] += B[k + i*P1D] * U[i]; // Contract z direction
5367d8d0e25Snbeams   }
5377d8d0e25Snbeams   for (int k = P1D; k < Q1D; ++k)
5387d8d0e25Snbeams     V[k] = 0.0;
5397d8d0e25Snbeams }
5407d8d0e25Snbeams 
5417d8d0e25Snbeams //------------------------------------------------------------------------------
5427d8d0e25Snbeams // 3D transpose tensor contract y
5437d8d0e25Snbeams //------------------------------------------------------------------------------
5447d8d0e25Snbeams inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
5457d8d0e25Snbeams                                             const int tidy, const int tidz,
5467d8d0e25Snbeams                                             const CeedScalar *U,
5477d8d0e25Snbeams                                             const CeedScalar *B,
5487d8d0e25Snbeams                                             CeedScalar *V) {
5497d8d0e25Snbeams   for (int k = 0; k < P1D; ++k) {
5507d8d0e25Snbeams     slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
5517d8d0e25Snbeams     __syncthreads();
5527d8d0e25Snbeams     V[k] = 0.0;
5537d8d0e25Snbeams     if (tidx < Q1D && tidy < P1D)
5547d8d0e25Snbeams       for (int i = 0; i < Q1D; ++i)
5557d8d0e25Snbeams         V[k] += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
5567d8d0e25Snbeams     __syncthreads();
5577d8d0e25Snbeams   }
5587d8d0e25Snbeams }
5597d8d0e25Snbeams 
5607d8d0e25Snbeams //------------------------------------------------------------------------------
5617d8d0e25Snbeams // 3D transpose tensor contract x
5627d8d0e25Snbeams //------------------------------------------------------------------------------
5637d8d0e25Snbeams inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
5647d8d0e25Snbeams                                             const int tidy, const int tidz,
5657d8d0e25Snbeams                                             const CeedScalar *U,
5667d8d0e25Snbeams                                             const CeedScalar *B,
5677d8d0e25Snbeams                                             CeedScalar *V) {
5687d8d0e25Snbeams   for (int k = 0; k < P1D; ++k) {
5697d8d0e25Snbeams     slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
5707d8d0e25Snbeams     __syncthreads();
5717d8d0e25Snbeams     V[k] = 0.0;
5727d8d0e25Snbeams     if (tidx < P1D && tidy < P1D)
5737d8d0e25Snbeams       for (int i = 0; i < Q1D; ++i)
5747d8d0e25Snbeams         V[k] += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
5757d8d0e25Snbeams     __syncthreads();
5767d8d0e25Snbeams   }
5777d8d0e25Snbeams }
5787d8d0e25Snbeams 
5797d8d0e25Snbeams //------------------------------------------------------------------------------
5807d8d0e25Snbeams // 3D interpolate to quadrature points
5817d8d0e25Snbeams //------------------------------------------------------------------------------
5827d8d0e25Snbeams inline __device__ void interp3d(const CeedInt nelem, const int transpose,
5837d8d0e25Snbeams                                 const CeedScalar *c_B,
5847d8d0e25Snbeams                                 const CeedScalar *__restrict__ d_U,
5857d8d0e25Snbeams                                 CeedScalar *__restrict__ d_V,
5867d8d0e25Snbeams                                 CeedScalar *slice) {
5877d8d0e25Snbeams   CeedScalar r_V[T1D];
5887d8d0e25Snbeams   CeedScalar r_t[T1D];
5897d8d0e25Snbeams 
5907d8d0e25Snbeams   const int tidx = threadIdx.x;
5917d8d0e25Snbeams   const int tidy = threadIdx.y;
5927d8d0e25Snbeams   const int tidz = threadIdx.z;
5937d8d0e25Snbeams   const int blockElem = tidz/BASIS_NCOMP;
5947d8d0e25Snbeams   const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
5957d8d0e25Snbeams   const int comp = tidz%BASIS_NCOMP;
5967d8d0e25Snbeams 
5977d8d0e25Snbeams   for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
5987d8d0e25Snbeams        elem += gridDim.x*elemsPerBlock) {
5997d8d0e25Snbeams     for (int i = 0; i < T1D; ++i) {
6007d8d0e25Snbeams       r_V[i] = 0.0;
6017d8d0e25Snbeams       r_t[i] = 0.0;
6027d8d0e25Snbeams     }
6037d8d0e25Snbeams     if (!transpose) {
6047d8d0e25Snbeams       readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
6057d8d0e25Snbeams       ContractX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
6067d8d0e25Snbeams       ContractY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
6077d8d0e25Snbeams       ContractZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
6087d8d0e25Snbeams       writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
6097d8d0e25Snbeams     } else {
6107d8d0e25Snbeams       readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
6117d8d0e25Snbeams       ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
6127d8d0e25Snbeams       ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
6137d8d0e25Snbeams       ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
6147d8d0e25Snbeams       writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
6157d8d0e25Snbeams     }
6167d8d0e25Snbeams   }
6177d8d0e25Snbeams }
6187d8d0e25Snbeams 
6197d8d0e25Snbeams //------------------------------------------------------------------------------
6207d8d0e25Snbeams // 3D derivatives at quadrature points
6217d8d0e25Snbeams //------------------------------------------------------------------------------
6227d8d0e25Snbeams inline __device__ void grad3d(const CeedInt nelem, const int transpose,
6237d8d0e25Snbeams                               const CeedScalar *c_B, const CeedScalar *c_G,
6247d8d0e25Snbeams                               const CeedScalar *__restrict__ d_U,
6257d8d0e25Snbeams                               CeedScalar *__restrict__ d_V,
6267d8d0e25Snbeams                               CeedScalar *slice) {
6277d8d0e25Snbeams   // Use P1D for one of these
6287d8d0e25Snbeams   CeedScalar r_U[T1D];
6297d8d0e25Snbeams   CeedScalar r_V[T1D];
6307d8d0e25Snbeams   CeedScalar r_t[T1D];
6317d8d0e25Snbeams 
6327d8d0e25Snbeams   const int tidx = threadIdx.x;
6337d8d0e25Snbeams   const int tidy = threadIdx.y;
6347d8d0e25Snbeams   const int tidz = threadIdx.z;
6357d8d0e25Snbeams   const int blockElem = tidz/BASIS_NCOMP;
6367d8d0e25Snbeams   const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
6377d8d0e25Snbeams   const int comp = tidz%BASIS_NCOMP;
6387d8d0e25Snbeams   int dim;
6397d8d0e25Snbeams 
6407d8d0e25Snbeams   for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
6417d8d0e25Snbeams        elem += gridDim.x*elemsPerBlock) {
6427d8d0e25Snbeams     for (int i = 0; i < T1D; ++i) {
6437d8d0e25Snbeams       r_U[i] = 0.0;
6447d8d0e25Snbeams       r_V[i] = 0.0;
6457d8d0e25Snbeams       r_t[i] = 0.0;
6467d8d0e25Snbeams     }
6477d8d0e25Snbeams     if (!transpose) {
6487d8d0e25Snbeams       readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
6497d8d0e25Snbeams       ContractX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
6507d8d0e25Snbeams       ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
6517d8d0e25Snbeams       ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
6527d8d0e25Snbeams       dim = 0;
6537d8d0e25Snbeams       writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
6547d8d0e25Snbeams       ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
6557d8d0e25Snbeams       ContractY3d(slice, tidx, tidy, tidz, r_V, c_G, r_t);
6567d8d0e25Snbeams       ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
6577d8d0e25Snbeams       dim = 1;
6587d8d0e25Snbeams       writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
6597d8d0e25Snbeams       ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
6607d8d0e25Snbeams       ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
6617d8d0e25Snbeams       ContractZ3d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
6627d8d0e25Snbeams       dim = 2;
6637d8d0e25Snbeams       writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
6647d8d0e25Snbeams     } else {
6657d8d0e25Snbeams       dim = 0;
6667d8d0e25Snbeams       readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
6677d8d0e25Snbeams       ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
6687d8d0e25Snbeams       ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
6697d8d0e25Snbeams       ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
6707d8d0e25Snbeams       dim = 1;
6717d8d0e25Snbeams       readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
6727d8d0e25Snbeams       ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
6737d8d0e25Snbeams       ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_G, r_U);
6747d8d0e25Snbeams       ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
6757d8d0e25Snbeams       add(r_V, r_t);
6767d8d0e25Snbeams       dim = 2;
6777d8d0e25Snbeams       readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
6787d8d0e25Snbeams       ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
6797d8d0e25Snbeams       ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
6807d8d0e25Snbeams       ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
6817d8d0e25Snbeams       add(r_V, r_t);
6827d8d0e25Snbeams       writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
6837d8d0e25Snbeams     }
6847d8d0e25Snbeams   }
6857d8d0e25Snbeams }
6867d8d0e25Snbeams 
6877d8d0e25Snbeams //------------------------------------------------------------------------------
6887d8d0e25Snbeams // 3D quadrature weights
6897d8d0e25Snbeams //------------------------------------------------------------------------------
6907d8d0e25Snbeams __device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
6917d8d0e25Snbeams                          CeedScalar *w) {
6927d8d0e25Snbeams   const int i = threadIdx.x;
6937d8d0e25Snbeams   const int j = threadIdx.y;
6947d8d0e25Snbeams   const int k = threadIdx.z;
6957d8d0e25Snbeams   const CeedScalar weight = qweight1d[i]*qweight1d[j]*qweight1d[k];
6967d8d0e25Snbeams   for (int e = blockIdx.x; e < nelem; e += gridDim.x) {
6977d8d0e25Snbeams     const int ind = e*Q1D*Q1D*Q1D + i + j*Q1D + k*Q1D*Q1D;
6987d8d0e25Snbeams     w[ind] = weight;
6997d8d0e25Snbeams   }
7007d8d0e25Snbeams }
7017d8d0e25Snbeams 
7027d8d0e25Snbeams 
7037d8d0e25Snbeams //------------------------------------------------------------------------------
7047d8d0e25Snbeams // Basis kernels
7057d8d0e25Snbeams //------------------------------------------------------------------------------
7067d8d0e25Snbeams 
7077d8d0e25Snbeams //------------------------------------------------------------------------------
7087d8d0e25Snbeams // Interp kernel by dim
7097d8d0e25Snbeams //------------------------------------------------------------------------------
7109e31c45bSnbeams extern "C" __launch_bounds__(INTERP_BLKSIZE) __global__ void interp(
7119e31c45bSnbeams                                   const CeedInt nelem, const int transpose,
7127d8d0e25Snbeams                                   const CeedScalar *c_B,
7137d8d0e25Snbeams                                   const CeedScalar *__restrict__ d_U,
7147d8d0e25Snbeams                                   CeedScalar *__restrict__ d_V) {
7157d8d0e25Snbeams   HIP_DYNAMIC_SHARED( double, slice)
7167d8d0e25Snbeams   if (BASIS_DIM == 1) {
7177d8d0e25Snbeams     interp1d(nelem, transpose, c_B, d_U, d_V, slice);
7187d8d0e25Snbeams   } else if (BASIS_DIM == 2) {
7197d8d0e25Snbeams     interp2d(nelem, transpose, c_B, d_U, d_V, slice);
7207d8d0e25Snbeams   } else if (BASIS_DIM == 3) {
7217d8d0e25Snbeams     interp3d(nelem, transpose, c_B, d_U, d_V, slice);
7227d8d0e25Snbeams   }
7237d8d0e25Snbeams }
7247d8d0e25Snbeams 
7257d8d0e25Snbeams //------------------------------------------------------------------------------
7267d8d0e25Snbeams // Grad kernel by dim
7277d8d0e25Snbeams //------------------------------------------------------------------------------
7289e31c45bSnbeams extern "C" __launch_bounds__(GRAD_BLKSIZE) __global__ void grad(const CeedInt nelem,
7299e31c45bSnbeams                                 const int transpose,
7307d8d0e25Snbeams                                 const CeedScalar *c_B, const CeedScalar *c_G,
7317d8d0e25Snbeams                                 const CeedScalar *__restrict__ d_U,
7327d8d0e25Snbeams                                 CeedScalar *__restrict__ d_V) {
7337d8d0e25Snbeams   HIP_DYNAMIC_SHARED( double, slice)
7347d8d0e25Snbeams   if (BASIS_DIM == 1) {
7357d8d0e25Snbeams     grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
7367d8d0e25Snbeams   } else if (BASIS_DIM == 2) {
7377d8d0e25Snbeams     grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
7387d8d0e25Snbeams   } else if (BASIS_DIM == 3) {
7397d8d0e25Snbeams     grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
7407d8d0e25Snbeams   }
7417d8d0e25Snbeams }
7427d8d0e25Snbeams 
7437d8d0e25Snbeams //------------------------------------------------------------------------------
7447d8d0e25Snbeams // Weight kernels by dim
7457d8d0e25Snbeams //------------------------------------------------------------------------------
7469e31c45bSnbeams extern "C" __launch_bounds__(WEIGHT_BLKSIZE) __global__ void weight(const CeedInt nelem,
7477d8d0e25Snbeams                                   const CeedScalar *__restrict__ qweight1d,
7487d8d0e25Snbeams                                   CeedScalar *__restrict__ v) {
7497d8d0e25Snbeams   if (BASIS_DIM == 1) {
7507d8d0e25Snbeams     weight1d(nelem, qweight1d, v);
7517d8d0e25Snbeams   } else if (BASIS_DIM == 2) {
7527d8d0e25Snbeams     weight2d(nelem, qweight1d, v);
7537d8d0e25Snbeams   } else if (BASIS_DIM == 3) {
7547d8d0e25Snbeams     weight3d(nelem, qweight1d, v);
7557d8d0e25Snbeams   }
7567d8d0e25Snbeams }
7577d8d0e25Snbeams 
7587d8d0e25Snbeams );
7597d8d0e25Snbeams // *INDENT-ON*
7607d8d0e25Snbeams 
7617d8d0e25Snbeams //------------------------------------------------------------------------------
7629e31c45bSnbeams // Compute a block size based on required minimum threads
7639e31c45bSnbeams //------------------------------------------------------------------------------
7649e31c45bSnbeams static CeedInt ComputeBlockSizeFromRequirement(const CeedInt required) {
7659e31c45bSnbeams   CeedInt maxSize = 1024;    // Max total threads per block
7669e31c45bSnbeams   CeedInt currentSize = 64;  // Start with one group
7679e31c45bSnbeams 
7689e31c45bSnbeams   while(currentSize < maxSize) {
7699e31c45bSnbeams     if (currentSize > required)
7709e31c45bSnbeams       break;
7719e31c45bSnbeams     else
7729e31c45bSnbeams       currentSize = currentSize * 2;
7739e31c45bSnbeams   }
7749e31c45bSnbeams   return currentSize;
7759e31c45bSnbeams }
7769e31c45bSnbeams 
7779e31c45bSnbeams //------------------------------------------------------------------------------
7789e31c45bSnbeams // Compute required thread block sizes for basis kernels given P, Q, dim, and
7799e31c45bSnbeams // ncomp
7809e31c45bSnbeams //------------------------------------------------------------------------------
7819e31c45bSnbeams static int ComputeBasisThreadBlockSizes(const CeedInt dim, const CeedInt P1d,
7829e31c45bSnbeams                                         const CeedInt Q1d,
7839e31c45bSnbeams                                         const CeedInt ncomp, CeedInt *blksizes) {
7849e31c45bSnbeams 
7859e31c45bSnbeams   // Note that this will use the same block sizes for all dimensions when compiling,
7869e31c45bSnbeams   // but as each basis object is defined for a particular dimension, we will never
7879e31c45bSnbeams   // call any kernels except the ones for the dimension for which we have computed the
7889e31c45bSnbeams   // block sizes.
7899e31c45bSnbeams   const CeedInt thread1d = CeedIntMax(P1d, Q1d);
7909e31c45bSnbeams   switch (dim) {
7919e31c45bSnbeams   case 1: {
7929e31c45bSnbeams     // Interp kernels:
7939e31c45bSnbeams     blksizes[0] = 256;
7949e31c45bSnbeams 
7959e31c45bSnbeams     // Grad kernels:
7969e31c45bSnbeams     blksizes[1] = 256;
7979e31c45bSnbeams 
7989e31c45bSnbeams     // Weight kernels:
7999e31c45bSnbeams     blksizes[2] = 256;
8009e31c45bSnbeams 
8019e31c45bSnbeams   } break;
8029e31c45bSnbeams   case 2: {
8039e31c45bSnbeams     // Interp kernels:
8049e31c45bSnbeams     CeedInt required = thread1d * thread1d * ncomp;
8059e31c45bSnbeams     blksizes[0]  = ComputeBlockSizeFromRequirement(required);
8069e31c45bSnbeams 
8079e31c45bSnbeams     // Grad kernels: currently use same required minimum threads
8089e31c45bSnbeams     blksizes[1]  = ComputeBlockSizeFromRequirement(required);
8099e31c45bSnbeams 
8109e31c45bSnbeams     // Weight kernels:
8119e31c45bSnbeams     required = CeedIntMax(64, Q1d * Q1d);
8129e31c45bSnbeams     blksizes[2]  = ComputeBlockSizeFromRequirement(required);
8139e31c45bSnbeams 
8149e31c45bSnbeams   } break;
8159e31c45bSnbeams   case 3: {
8169e31c45bSnbeams     // Interp kernels:
8179e31c45bSnbeams     CeedInt required = thread1d * thread1d * ncomp;
8189e31c45bSnbeams     blksizes[0]  = ComputeBlockSizeFromRequirement(required);
8199e31c45bSnbeams 
8209e31c45bSnbeams     // Grad kernels: currently use same required minimum threads
8219e31c45bSnbeams     blksizes[1]  = ComputeBlockSizeFromRequirement(required);
8229e31c45bSnbeams 
8239e31c45bSnbeams     // Weight kernels:
8249e31c45bSnbeams     required = Q1d * Q1d * Q1d;
8259e31c45bSnbeams     blksizes[2]  = ComputeBlockSizeFromRequirement(required);
8269e31c45bSnbeams   }
8279e31c45bSnbeams   }
8289e31c45bSnbeams 
8299e31c45bSnbeams   return 0;
8309e31c45bSnbeams }
8319e31c45bSnbeams 
8329e31c45bSnbeams //------------------------------------------------------------------------------
8337d8d0e25Snbeams // Apply basis
8347d8d0e25Snbeams //------------------------------------------------------------------------------
8357d8d0e25Snbeams int CeedBasisApplyTensor_Hip_shared(CeedBasis basis, const CeedInt nelem,
8367d8d0e25Snbeams                                     CeedTransposeMode tmode,
8377d8d0e25Snbeams                                     CeedEvalMode emode, CeedVector u,
8387d8d0e25Snbeams                                     CeedVector v) {
8397d8d0e25Snbeams   int ierr;
8407d8d0e25Snbeams   Ceed ceed;
8417d8d0e25Snbeams   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
8427d8d0e25Snbeams   Ceed_Hip_shared *ceed_Hip;
8437d8d0e25Snbeams   CeedGetData(ceed, &ceed_Hip); CeedChk(ierr);
8447d8d0e25Snbeams   CeedBasis_Hip_shared *data;
8457d8d0e25Snbeams   CeedBasisGetData(basis, &data); CeedChk(ierr);
8467d8d0e25Snbeams   const CeedInt transpose = tmode == CEED_TRANSPOSE;
8477d8d0e25Snbeams   CeedInt dim, ncomp;
8487d8d0e25Snbeams   ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
8497d8d0e25Snbeams   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
8507d8d0e25Snbeams 
8517d8d0e25Snbeams   // Read vectors
8527d8d0e25Snbeams   const CeedScalar *d_u;
8537d8d0e25Snbeams   CeedScalar *d_v;
8547d8d0e25Snbeams   if (emode != CEED_EVAL_WEIGHT) {
8557d8d0e25Snbeams     ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChk(ierr);
8567d8d0e25Snbeams   }
8577d8d0e25Snbeams   ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChk(ierr);
8587d8d0e25Snbeams 
8597d8d0e25Snbeams   // Clear v for transpose mode
8607d8d0e25Snbeams   if (tmode == CEED_TRANSPOSE) {
8617d8d0e25Snbeams     CeedInt length;
8627d8d0e25Snbeams     ierr = CeedVectorGetLength(v, &length); CeedChk(ierr);
8637d8d0e25Snbeams     ierr = hipMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChk(ierr);
8647d8d0e25Snbeams   }
8657d8d0e25Snbeams 
8667d8d0e25Snbeams   // Apply basis operation
8677d8d0e25Snbeams   switch (emode) {
8687d8d0e25Snbeams   case CEED_EVAL_INTERP: {
8697d8d0e25Snbeams     CeedInt P1d, Q1d;
8709e31c45bSnbeams     CeedInt blksize = data->blksizes[0];
8717d8d0e25Snbeams     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
8727d8d0e25Snbeams     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
8737d8d0e25Snbeams     CeedInt thread1d = CeedIntMax(Q1d, P1d);
8747d8d0e25Snbeams     ierr = CeedHipInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B);
8757d8d0e25Snbeams     CeedChk(ierr);
8767d8d0e25Snbeams     void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
8777d8d0e25Snbeams                           &d_u, &d_v
8787d8d0e25Snbeams                          };
8797d8d0e25Snbeams     if (dim == 1) {
880e7ea6884Snbeams       CeedInt elemsPerBlock = 64*thread1d > 256? 256/thread1d : 64;
8817d8d0e25Snbeams       elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1;
8827d8d0e25Snbeams       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
8837d8d0e25Snbeams                                              ? 1 : 0 );
8847d8d0e25Snbeams       CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
8857d8d0e25Snbeams       ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, 1,
8867d8d0e25Snbeams                                        elemsPerBlock, sharedMem,
8877d8d0e25Snbeams                                        interpargs); CeedChk(ierr);
8887d8d0e25Snbeams     } else if (dim == 2) {
8899e31c45bSnbeams       // Check if required threads is small enough to do multiple elems
8909e31c45bSnbeams       const CeedInt elemsPerBlock = CeedIntMax(blksize/(thread1d*thread1d*ncomp), 1);
8917d8d0e25Snbeams       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
8927d8d0e25Snbeams                                              ? 1 : 0 );
8937d8d0e25Snbeams       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
8947d8d0e25Snbeams       ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d,
8957d8d0e25Snbeams                                        ncomp*elemsPerBlock, sharedMem,
8967d8d0e25Snbeams                                        interpargs); CeedChk(ierr);
8977d8d0e25Snbeams     } else if (dim == 3) {
8987d8d0e25Snbeams       CeedInt elemsPerBlock = 1;
8997d8d0e25Snbeams       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
9007d8d0e25Snbeams                                              ? 1 : 0 );
9017d8d0e25Snbeams       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
9027d8d0e25Snbeams       ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d,
9037d8d0e25Snbeams                                        ncomp*elemsPerBlock, sharedMem,
9047d8d0e25Snbeams                                        interpargs); CeedChk(ierr);
9057d8d0e25Snbeams     }
9067d8d0e25Snbeams   } break;
9077d8d0e25Snbeams   case CEED_EVAL_GRAD: {
9087d8d0e25Snbeams     CeedInt P1d, Q1d;
9099e31c45bSnbeams     CeedInt blksize = data->blksizes[1];
9107d8d0e25Snbeams     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
9117d8d0e25Snbeams     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
9127d8d0e25Snbeams     CeedInt thread1d = CeedIntMax(Q1d, P1d);
9137d8d0e25Snbeams     ierr = CeedHipInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d,
9147d8d0e25Snbeams                                  Q1d, &data->c_B, &data->c_G);
9157d8d0e25Snbeams     CeedChk(ierr);
9167d8d0e25Snbeams     void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
9177d8d0e25Snbeams                         &data->c_G, &d_u, &d_v
9187d8d0e25Snbeams                        };
9197d8d0e25Snbeams     if (dim == 1) {
920e7ea6884Snbeams       CeedInt elemsPerBlock = 64*thread1d > 256? 256/thread1d : 64;
9217d8d0e25Snbeams       elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1;
9227d8d0e25Snbeams       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
9237d8d0e25Snbeams                                              ? 1 : 0 );
9247d8d0e25Snbeams       CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
9257d8d0e25Snbeams       ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, 1,
9267d8d0e25Snbeams                                        elemsPerBlock, sharedMem, gradargs);
9277d8d0e25Snbeams       CeedChk(ierr);
9287d8d0e25Snbeams     } else if (dim == 2) {
9299e31c45bSnbeams       // Check if required threads is small enough to do multiple elems
9309e31c45bSnbeams       const CeedInt elemsPerBlock = CeedIntMax(blksize/(thread1d*thread1d*ncomp), 1);
9317d8d0e25Snbeams       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
9327d8d0e25Snbeams                                              ? 1 : 0 );
9337d8d0e25Snbeams       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
9347d8d0e25Snbeams       ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d,
9357d8d0e25Snbeams                                        ncomp*elemsPerBlock, sharedMem,
9367d8d0e25Snbeams                                        gradargs); CeedChk(ierr);
9377d8d0e25Snbeams     } else if (dim == 3) {
9387d8d0e25Snbeams       CeedInt elemsPerBlock = 1;
9397d8d0e25Snbeams       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
9407d8d0e25Snbeams                                              ? 1 : 0 );
9417d8d0e25Snbeams       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
9427d8d0e25Snbeams       ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d,
9437d8d0e25Snbeams                                        ncomp*elemsPerBlock, sharedMem,
9447d8d0e25Snbeams                                        gradargs); CeedChk(ierr);
9457d8d0e25Snbeams     }
9467d8d0e25Snbeams   } break;
9477d8d0e25Snbeams   case CEED_EVAL_WEIGHT: {
9487d8d0e25Snbeams     CeedInt Q1d;
9499e31c45bSnbeams     CeedInt blksize = data->blksizes[2];
9507d8d0e25Snbeams     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
9517d8d0e25Snbeams     void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
9527d8d0e25Snbeams     if (dim == 1) {
9539e31c45bSnbeams       const CeedInt optElems = blksize/Q1d;
9547d8d0e25Snbeams       const CeedInt elemsPerBlock = optElems>0?optElems:1;
9557d8d0e25Snbeams       const CeedInt gridsize = nelem/elemsPerBlock + ( (
9567d8d0e25Snbeams                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
9577d8d0e25Snbeams       ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d,
9587d8d0e25Snbeams                                  elemsPerBlock, 1, weightargs);
9597d8d0e25Snbeams       CeedChk(ierr);
9607d8d0e25Snbeams     } else if (dim == 2) {
9619e31c45bSnbeams       const CeedInt optElems = blksize/(Q1d*Q1d);
9627d8d0e25Snbeams       const CeedInt elemsPerBlock = optElems>0?optElems:1;
9637d8d0e25Snbeams       const CeedInt gridsize = nelem/elemsPerBlock + ( (
9647d8d0e25Snbeams                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
9657d8d0e25Snbeams       ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d,
9667d8d0e25Snbeams                                  elemsPerBlock, weightargs);
9677d8d0e25Snbeams       CeedChk(ierr);
9687d8d0e25Snbeams     } else if (dim == 3) {
9697d8d0e25Snbeams       const CeedInt gridsize = nelem;
9707d8d0e25Snbeams       ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d, Q1d,
9717d8d0e25Snbeams                                  weightargs);
9727d8d0e25Snbeams       CeedChk(ierr);
9737d8d0e25Snbeams     }
9747d8d0e25Snbeams   } break;
9757d8d0e25Snbeams   // LCOV_EXCL_START
9767d8d0e25Snbeams   // Evaluate the divergence to/from the quadrature points
9777d8d0e25Snbeams   case CEED_EVAL_DIV:
9787d8d0e25Snbeams     return CeedError(ceed, 1, "CEED_EVAL_DIV not supported");
9797d8d0e25Snbeams   // Evaluate the curl to/from the quadrature points
9807d8d0e25Snbeams   case CEED_EVAL_CURL:
9817d8d0e25Snbeams     return CeedError(ceed, 1, "CEED_EVAL_CURL not supported");
9827d8d0e25Snbeams   // Take no action, BasisApply should not have been called
9837d8d0e25Snbeams   case CEED_EVAL_NONE:
9847d8d0e25Snbeams     return CeedError(ceed, 1,
9857d8d0e25Snbeams                      "CEED_EVAL_NONE does not make sense in this context");
9867d8d0e25Snbeams     // LCOV_EXCL_STOP
9877d8d0e25Snbeams   }
9887d8d0e25Snbeams 
9897d8d0e25Snbeams   // Restore vectors
9907d8d0e25Snbeams   if (emode != CEED_EVAL_WEIGHT) {
9917d8d0e25Snbeams     ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChk(ierr);
9927d8d0e25Snbeams   }
9937d8d0e25Snbeams   ierr = CeedVectorRestoreArray(v, &d_v); CeedChk(ierr);
9947d8d0e25Snbeams   return 0;
9957d8d0e25Snbeams }
9967d8d0e25Snbeams 
9977d8d0e25Snbeams //------------------------------------------------------------------------------
9987d8d0e25Snbeams // Destroy basis
9997d8d0e25Snbeams //------------------------------------------------------------------------------
10007d8d0e25Snbeams static int CeedBasisDestroy_Hip_shared(CeedBasis basis) {
10017d8d0e25Snbeams   int ierr;
10027d8d0e25Snbeams   Ceed ceed;
10037d8d0e25Snbeams   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
10047d8d0e25Snbeams 
10057d8d0e25Snbeams   CeedBasis_Hip_shared *data;
10067d8d0e25Snbeams   ierr = CeedBasisGetData(basis, &data); CeedChk(ierr);
10077d8d0e25Snbeams 
10087d8d0e25Snbeams   CeedChk_Hip(ceed, hipModuleUnload(data->module));
10097d8d0e25Snbeams 
10107d8d0e25Snbeams   ierr = hipFree(data->d_qweight1d); CeedChk_Hip(ceed, ierr);
10117d8d0e25Snbeams   ierr = hipFree(data->d_interp1d); CeedChk_Hip(ceed, ierr);
10127d8d0e25Snbeams   ierr = hipFree(data->d_grad1d); CeedChk_Hip(ceed, ierr);
10137d8d0e25Snbeams   ierr = hipFree(data->d_collograd1d); CeedChk_Hip(ceed, ierr);
10147d8d0e25Snbeams 
10157d8d0e25Snbeams   ierr = CeedFree(&data); CeedChk(ierr);
10167d8d0e25Snbeams 
10177d8d0e25Snbeams   return 0;
10187d8d0e25Snbeams }
10197d8d0e25Snbeams 
10207d8d0e25Snbeams //------------------------------------------------------------------------------
10217d8d0e25Snbeams // Create tensor basis
10227d8d0e25Snbeams //------------------------------------------------------------------------------
10237d8d0e25Snbeams int CeedBasisCreateTensorH1_Hip_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
10247d8d0e25Snbeams                                        const CeedScalar *interp1d,
10257d8d0e25Snbeams                                        const CeedScalar *grad1d,
10267d8d0e25Snbeams                                        const CeedScalar *qref1d,
10277d8d0e25Snbeams                                        const CeedScalar *qweight1d,
10287d8d0e25Snbeams                                        CeedBasis basis) {
10297d8d0e25Snbeams   int ierr;
10307d8d0e25Snbeams   Ceed ceed;
10317d8d0e25Snbeams   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
10327d8d0e25Snbeams   CeedBasis_Hip_shared *data;
10337d8d0e25Snbeams   ierr = CeedCalloc(1, &data); CeedChk(ierr);
10347d8d0e25Snbeams 
10357d8d0e25Snbeams   // Copy basis data to GPU
10367d8d0e25Snbeams   const CeedInt qBytes = Q1d * sizeof(CeedScalar);
10377d8d0e25Snbeams   ierr = hipMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Hip(ceed, ierr);
10387d8d0e25Snbeams   ierr = hipMemcpy(data->d_qweight1d, qweight1d, qBytes,
10397d8d0e25Snbeams                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
10407d8d0e25Snbeams 
10417d8d0e25Snbeams   const CeedInt iBytes = qBytes * P1d;
10427d8d0e25Snbeams   ierr = hipMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Hip(ceed, ierr);
10437d8d0e25Snbeams   ierr = hipMemcpy(data->d_interp1d, interp1d, iBytes,
10447d8d0e25Snbeams                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
10457d8d0e25Snbeams 
10467d8d0e25Snbeams   ierr = hipMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Hip(ceed, ierr);
10477d8d0e25Snbeams   ierr = hipMemcpy(data->d_grad1d, grad1d, iBytes,
10487d8d0e25Snbeams                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
10497d8d0e25Snbeams 
10507d8d0e25Snbeams   // Compute collocated gradient and copy to GPU
10517d8d0e25Snbeams   data->d_collograd1d = NULL;
10527d8d0e25Snbeams   if (dim == 3 && Q1d >= P1d) {
10537d8d0e25Snbeams     CeedScalar *collograd1d;
10547d8d0e25Snbeams     ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChk(ierr);
10557d8d0e25Snbeams     ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChk(ierr);
10567d8d0e25Snbeams     ierr = hipMalloc((void **)&data->d_collograd1d, qBytes * Q1d);
10577d8d0e25Snbeams     CeedChk_Hip(ceed, ierr);
10587d8d0e25Snbeams     ierr = hipMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d,
10597d8d0e25Snbeams                      hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
10607d8d0e25Snbeams     ierr = CeedFree(&collograd1d); CeedChk(ierr);
10617d8d0e25Snbeams   }
10627d8d0e25Snbeams 
10639e31c45bSnbeams   // Set number of threads per block for basis kernels
10647d8d0e25Snbeams   CeedInt ncomp;
10657d8d0e25Snbeams   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
10669e31c45bSnbeams   ierr = ComputeBasisThreadBlockSizes(dim, P1d, Q1d, ncomp, data->blksizes);
10679e31c45bSnbeams   CeedChk(ierr);
10689e31c45bSnbeams 
10699e31c45bSnbeams   // Compile basis kernels
10709e31c45bSnbeams   ierr = CeedCompileHip(ceed, kernelsShared, &data->module, 11,
10717d8d0e25Snbeams                         "Q1D", Q1d,
10727d8d0e25Snbeams                         "P1D", P1d,
10737d8d0e25Snbeams                         "T1D", CeedIntMax(Q1d, P1d),
10747d8d0e25Snbeams                         "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
10757d8d0e25Snbeams                             Q1d : P1d, dim),
10767d8d0e25Snbeams                         "BASIS_DIM", dim,
10777d8d0e25Snbeams                         "BASIS_NCOMP", ncomp,
10787d8d0e25Snbeams                         "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
10799e31c45bSnbeams                         "BASIS_NQPT", CeedIntPow(Q1d, dim),
10809e31c45bSnbeams                         "INTERP_BLKSIZE", data->blksizes[0],
10819e31c45bSnbeams                         "GRAD_BLKSIZE", data->blksizes[1],
10829e31c45bSnbeams                         "WEIGHT_BLKSIZE", data->blksizes[2]
10837d8d0e25Snbeams                        ); CeedChk(ierr);
10847d8d0e25Snbeams   ierr = CeedGetKernelHip(ceed, data->module, "interp", &data->interp);
10857d8d0e25Snbeams   CeedChk(ierr);
10867d8d0e25Snbeams   ierr = CeedGetKernelHip(ceed, data->module, "grad", &data->grad);
10877d8d0e25Snbeams   CeedChk(ierr);
10887d8d0e25Snbeams   ierr = CeedGetKernelHip(ceed, data->module, "weight", &data->weight);
10897d8d0e25Snbeams   CeedChk(ierr);
10907d8d0e25Snbeams 
10917d8d0e25Snbeams   ierr = CeedBasisSetData(basis, data); CeedChk(ierr);
10927d8d0e25Snbeams 
10937d8d0e25Snbeams   // Register backend functions
10947d8d0e25Snbeams   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
10957d8d0e25Snbeams                                 CeedBasisApplyTensor_Hip_shared);
10967d8d0e25Snbeams   CeedChk(ierr);
10977d8d0e25Snbeams   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
10987d8d0e25Snbeams                                 CeedBasisDestroy_Hip_shared); CeedChk(ierr);
10997d8d0e25Snbeams   return 0;
11007d8d0e25Snbeams }
11017d8d0e25Snbeams //------------------------------------------------------------------------------
1102