17d8d0e25Snbeams // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC. 27d8d0e25Snbeams // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707. 37d8d0e25Snbeams // All Rights reserved. See files LICENSE and NOTICE for details. 47d8d0e25Snbeams // 57d8d0e25Snbeams // This file is part of CEED, a collection of benchmarks, miniapps, software 67d8d0e25Snbeams // libraries and APIs for efficient high-order finite element and spectral 77d8d0e25Snbeams // element discretizations for exascale applications. For more information and 87d8d0e25Snbeams // source code availability see http://github.com/ceed. 97d8d0e25Snbeams // 107d8d0e25Snbeams // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC, 117d8d0e25Snbeams // a collaborative effort of two U.S. Department of Energy organizations (Office 127d8d0e25Snbeams // of Science and the National Nuclear Security Administration) responsible for 137d8d0e25Snbeams // the planning and preparation of a capable exascale ecosystem, including 147d8d0e25Snbeams // software, applications, hardware, advanced system engineering and early 157d8d0e25Snbeams // testbed platforms, in support of the nation's exascale computing imperative. 
167d8d0e25Snbeams 177d8d0e25Snbeams #include "ceed-hip-shared.h" 187d8d0e25Snbeams #include "../hip/ceed-hip-compile.h" 197d8d0e25Snbeams 207d8d0e25Snbeams //------------------------------------------------------------------------------ 217d8d0e25Snbeams // Shared mem kernels 227d8d0e25Snbeams //------------------------------------------------------------------------------ 237d8d0e25Snbeams // *INDENT-OFF* 247d8d0e25Snbeams static const char *kernelsShared = QUOTE( 257d8d0e25Snbeams 267d8d0e25Snbeams //------------------------------------------------------------------------------ 277d8d0e25Snbeams // Sum input into output 287d8d0e25Snbeams //------------------------------------------------------------------------------ 297d8d0e25Snbeams inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) { 307d8d0e25Snbeams for (int i = 0; i < P1D; i++) 317d8d0e25Snbeams r_V[i] += r_U[i]; 327d8d0e25Snbeams } 337d8d0e25Snbeams 347d8d0e25Snbeams //------------------------------------------------------------------------------ 357d8d0e25Snbeams // 1D 367d8d0e25Snbeams //------------------------------------------------------------------------------ 377d8d0e25Snbeams 387d8d0e25Snbeams //------------------------------------------------------------------------------ 397d8d0e25Snbeams // Read DoFs 407d8d0e25Snbeams //------------------------------------------------------------------------------ 417d8d0e25Snbeams inline __device__ void readDofs1d(const int elem, const int tidx, 427d8d0e25Snbeams const int tidy, const int tidz,const int comp, 437d8d0e25Snbeams const int nelem, const CeedScalar *d_U, 447d8d0e25Snbeams CeedScalar *slice) { 457d8d0e25Snbeams for (int i = 0; i < P1D; i++) 467d8d0e25Snbeams slice[i + tidz*T1D] = d_U[i + elem*P1D + comp*P1D*nelem]; 477d8d0e25Snbeams for (int i = P1D; i < Q1D; i++) 487d8d0e25Snbeams slice[i + tidz*T1D] = 0.0; 497d8d0e25Snbeams } 507d8d0e25Snbeams 517d8d0e25Snbeams 
//------------------------------------------------------------------------------ 527d8d0e25Snbeams // Write DoFs 537d8d0e25Snbeams //------------------------------------------------------------------------------ 547d8d0e25Snbeams inline __device__ void writeDofs1d(const int elem, const int tidx, 557d8d0e25Snbeams const int tidy, const int comp, 567d8d0e25Snbeams const int nelem, const CeedScalar &r_V, 577d8d0e25Snbeams CeedScalar *d_V) { 587d8d0e25Snbeams if (tidx<P1D) 597d8d0e25Snbeams d_V[tidx + elem*P1D + comp*P1D*nelem] = r_V; 607d8d0e25Snbeams } 617d8d0e25Snbeams 627d8d0e25Snbeams //------------------------------------------------------------------------------ 637d8d0e25Snbeams // Read quadrature point data 647d8d0e25Snbeams //------------------------------------------------------------------------------ 657d8d0e25Snbeams inline __device__ void readQuads1d(const int elem, const int tidx, 667d8d0e25Snbeams const int tidy, const int tidz, const int comp, 677d8d0e25Snbeams const int dim, const int nelem, 687d8d0e25Snbeams const CeedScalar *d_U, CeedScalar *slice) { 697d8d0e25Snbeams for (int i = 0; i < Q1D; i++) 707d8d0e25Snbeams slice[i + tidz*T1D] = d_U[i + elem*Q1D + comp*Q1D*nelem + 717d8d0e25Snbeams dim*BASIS_NCOMP*nelem*Q1D]; 727d8d0e25Snbeams for (int i = Q1D; i < P1D; i++) 737d8d0e25Snbeams slice[i + tidz*T1D] = 0.0; 747d8d0e25Snbeams } 757d8d0e25Snbeams 767d8d0e25Snbeams //------------------------------------------------------------------------------ 777d8d0e25Snbeams // Write quadrature point data 787d8d0e25Snbeams //------------------------------------------------------------------------------ 797d8d0e25Snbeams inline __device__ void writeQuads1d(const int elem, const int tidx, 807d8d0e25Snbeams const int tidy, const int comp, 817d8d0e25Snbeams const int dim, const int nelem, 827d8d0e25Snbeams const CeedScalar &r_V, CeedScalar *d_V) { 837d8d0e25Snbeams if (tidx<Q1D) 847d8d0e25Snbeams d_V[tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D] = 
r_V; 857d8d0e25Snbeams } 867d8d0e25Snbeams 877d8d0e25Snbeams //------------------------------------------------------------------------------ 887d8d0e25Snbeams // 1D tensor contraction 897d8d0e25Snbeams //------------------------------------------------------------------------------ 907d8d0e25Snbeams inline __device__ void ContractX1d(CeedScalar *slice, const int tidx, 917d8d0e25Snbeams const int tidy, const int tidz, 927d8d0e25Snbeams const CeedScalar &U, const CeedScalar *B, 937d8d0e25Snbeams CeedScalar &V) { 947d8d0e25Snbeams V = 0.0; 957d8d0e25Snbeams for (int i = 0; i < P1D; ++i) 967d8d0e25Snbeams V += B[i + tidx*P1D] * slice[i + tidz*T1D]; // Contract x direction 977d8d0e25Snbeams } 987d8d0e25Snbeams 997d8d0e25Snbeams //------------------------------------------------------------------------------ 1007d8d0e25Snbeams // 1D transpose tensor contraction 1017d8d0e25Snbeams //------------------------------------------------------------------------------ 1027d8d0e25Snbeams inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx, 1037d8d0e25Snbeams const int tidy, const int tidz, 1047d8d0e25Snbeams const CeedScalar &U, const CeedScalar *B, CeedScalar &V) { 1057d8d0e25Snbeams V = 0.0; 1067d8d0e25Snbeams for (int i = 0; i < Q1D; ++i) 1077d8d0e25Snbeams V += B[tidx + i*P1D] * slice[i + tidz*T1D]; // Contract x direction 1087d8d0e25Snbeams } 1097d8d0e25Snbeams 1107d8d0e25Snbeams //------------------------------------------------------------------------------ 1117d8d0e25Snbeams // 1D interpolate to quadrature points 1127d8d0e25Snbeams //------------------------------------------------------------------------------ 1137d8d0e25Snbeams inline __device__ void interp1d(const CeedInt nelem, const int transpose, 1147d8d0e25Snbeams const CeedScalar *c_B, 1157d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 1167d8d0e25Snbeams CeedScalar *__restrict__ d_V, 1177d8d0e25Snbeams CeedScalar *slice) { 1187d8d0e25Snbeams CeedScalar r_V; 1197d8d0e25Snbeams 
CeedScalar r_t; 1207d8d0e25Snbeams 1217d8d0e25Snbeams const int tidx = threadIdx.x; 1227d8d0e25Snbeams const int tidy = threadIdx.y; 1237d8d0e25Snbeams const int tidz = threadIdx.z; 1247d8d0e25Snbeams 1257d8d0e25Snbeams 1267d8d0e25Snbeams for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem; 1277d8d0e25Snbeams elem += gridDim.x*blockDim.z) { 1287d8d0e25Snbeams for (int comp = 0; comp < BASIS_NCOMP; comp++) { 1297d8d0e25Snbeams if (!transpose) { 1307d8d0e25Snbeams readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice); 1317d8d0e25Snbeams ContractX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V); 1327d8d0e25Snbeams writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V); 1337d8d0e25Snbeams } else { 1347d8d0e25Snbeams readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice); 1357d8d0e25Snbeams ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V); 1367d8d0e25Snbeams writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V); 1377d8d0e25Snbeams } 1387d8d0e25Snbeams } 1397d8d0e25Snbeams } 1407d8d0e25Snbeams } 1417d8d0e25Snbeams 1427d8d0e25Snbeams //------------------------------------------------------------------------------ 1437d8d0e25Snbeams // 1D derivatives at quadrature points 1447d8d0e25Snbeams //------------------------------------------------------------------------------ 1457d8d0e25Snbeams inline __device__ void grad1d(const CeedInt nelem, const int transpose, 1467d8d0e25Snbeams const CeedScalar *c_B, const CeedScalar *c_G, 1477d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 1487d8d0e25Snbeams CeedScalar *__restrict__ d_V, 1497d8d0e25Snbeams CeedScalar *slice) { 1507d8d0e25Snbeams CeedScalar r_U; 1517d8d0e25Snbeams CeedScalar r_V; 1527d8d0e25Snbeams 1537d8d0e25Snbeams const int tidx = threadIdx.x; 1547d8d0e25Snbeams const int tidy = threadIdx.y; 1557d8d0e25Snbeams const int tidz = threadIdx.z; 1567d8d0e25Snbeams int dim; 1577d8d0e25Snbeams 1587d8d0e25Snbeams for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem; 
1597d8d0e25Snbeams elem += gridDim.x*blockDim.z) { 1607d8d0e25Snbeams for(int comp = 0; comp < BASIS_NCOMP; comp++) { 1617d8d0e25Snbeams if (!transpose) { 1627d8d0e25Snbeams readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice); 1637d8d0e25Snbeams ContractX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V); 1647d8d0e25Snbeams dim = 0; 1657d8d0e25Snbeams writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V); 1667d8d0e25Snbeams } else { 1677d8d0e25Snbeams dim = 0; 1687d8d0e25Snbeams readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice); 1697d8d0e25Snbeams ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V); 1707d8d0e25Snbeams writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V); 1717d8d0e25Snbeams } 1727d8d0e25Snbeams } 1737d8d0e25Snbeams } 1747d8d0e25Snbeams } 1757d8d0e25Snbeams 1767d8d0e25Snbeams //------------------------------------------------------------------------------ 1777d8d0e25Snbeams // 1D Quadrature weights 1787d8d0e25Snbeams //------------------------------------------------------------------------------ 1797d8d0e25Snbeams __device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d, 1807d8d0e25Snbeams CeedScalar *w) { 1817d8d0e25Snbeams const int tid = threadIdx.x; 1827d8d0e25Snbeams const CeedScalar weight = qweight1d[tid]; 1837d8d0e25Snbeams for (CeedInt elem = blockIdx.x*blockDim.y + threadIdx.y; elem < nelem; 1847d8d0e25Snbeams elem += gridDim.x*blockDim.y) { 1857d8d0e25Snbeams const int ind = elem*Q1D + tid; 1867d8d0e25Snbeams w[ind] = weight; 1877d8d0e25Snbeams } 1887d8d0e25Snbeams } 1897d8d0e25Snbeams 1907d8d0e25Snbeams //------------------------------------------------------------------------------ 1917d8d0e25Snbeams // 2D 1927d8d0e25Snbeams //------------------------------------------------------------------------------ 1937d8d0e25Snbeams 1947d8d0e25Snbeams //------------------------------------------------------------------------------ 1957d8d0e25Snbeams // Read DoFs 1967d8d0e25Snbeams 
//------------------------------------------------------------------------------ 1977d8d0e25Snbeams inline __device__ void readDofs2d(const int elem, const int tidx, 1987d8d0e25Snbeams const int tidy, const int comp, 1997d8d0e25Snbeams const int nelem, const CeedScalar *d_U, 2007d8d0e25Snbeams CeedScalar &U) { 2017d8d0e25Snbeams U = (tidx<P1D && tidy<P1D) ? 2027d8d0e25Snbeams d_U[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] : 0.0; 2037d8d0e25Snbeams } 2047d8d0e25Snbeams 2057d8d0e25Snbeams //------------------------------------------------------------------------------ 2067d8d0e25Snbeams // Write DoFs 2077d8d0e25Snbeams //------------------------------------------------------------------------------ 2087d8d0e25Snbeams inline __device__ void writeDofs2d(const int elem, const int tidx, 2097d8d0e25Snbeams const int tidy, const int comp, 2107d8d0e25Snbeams const int nelem, const CeedScalar &r_V, 2117d8d0e25Snbeams CeedScalar *d_V) { 2127d8d0e25Snbeams if (tidx<P1D && tidy<P1D) 2137d8d0e25Snbeams d_V[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] = r_V; 2147d8d0e25Snbeams } 2157d8d0e25Snbeams 2167d8d0e25Snbeams //------------------------------------------------------------------------------ 2177d8d0e25Snbeams // Read quadrature point data 2187d8d0e25Snbeams //------------------------------------------------------------------------------ 2197d8d0e25Snbeams inline __device__ void readQuads2d(const int elem, const int tidx, 2207d8d0e25Snbeams const int tidy, const int comp, 2217d8d0e25Snbeams const int dim, const int nelem, 2227d8d0e25Snbeams const CeedScalar *d_U, CeedScalar &U ) { 2237d8d0e25Snbeams U = (tidx<Q1D && tidy<Q1D) ? 
2247d8d0e25Snbeams d_U[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem + 2257d8d0e25Snbeams dim*BASIS_NCOMP*nelem*Q1D*Q1D] : 0.0; 2267d8d0e25Snbeams } 2277d8d0e25Snbeams 2287d8d0e25Snbeams //------------------------------------------------------------------------------ 2297d8d0e25Snbeams // Write quadrature point data 2307d8d0e25Snbeams //------------------------------------------------------------------------------ 2317d8d0e25Snbeams inline __device__ void writeQuads2d(const int elem, const int tidx, 2327d8d0e25Snbeams const int tidy, const int comp, 2337d8d0e25Snbeams const int dim, const int nelem, 2347d8d0e25Snbeams const CeedScalar &r_V, CeedScalar *d_V) { 2357d8d0e25Snbeams if (tidx<Q1D && tidy<Q1D) 2367d8d0e25Snbeams d_V[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem + 2377d8d0e25Snbeams dim*BASIS_NCOMP*nelem*Q1D*Q1D] = r_V; 2387d8d0e25Snbeams } 2397d8d0e25Snbeams 2407d8d0e25Snbeams //------------------------------------------------------------------------------ 2417d8d0e25Snbeams // 2D tensor contraction x 2427d8d0e25Snbeams //------------------------------------------------------------------------------ 2437d8d0e25Snbeams inline __device__ void ContractX2d(CeedScalar *slice, const int tidx, 2447d8d0e25Snbeams const int tidy, const int tidz, 2457d8d0e25Snbeams const CeedScalar &U, const CeedScalar *B, 2467d8d0e25Snbeams CeedScalar &V) { 2477d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U; 2487d8d0e25Snbeams __syncthreads(); 2497d8d0e25Snbeams V = 0.0; 2507d8d0e25Snbeams if (tidx < Q1D) 2517d8d0e25Snbeams for (int i = 0; i < P1D; ++i) 2527d8d0e25Snbeams V += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction 2537d8d0e25Snbeams __syncthreads(); 2547d8d0e25Snbeams } 2557d8d0e25Snbeams 2567d8d0e25Snbeams //------------------------------------------------------------------------------ 2577d8d0e25Snbeams // 2D tensor contraction y 2587d8d0e25Snbeams 
//------------------------------------------------------------------------------ 2597d8d0e25Snbeams inline __device__ void ContractY2d(CeedScalar *slice, const int tidx, 2607d8d0e25Snbeams const int tidy, const int tidz, 2617d8d0e25Snbeams const CeedScalar &U, const CeedScalar *B, 2627d8d0e25Snbeams CeedScalar &V) { 2637d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U; 2647d8d0e25Snbeams __syncthreads(); 2657d8d0e25Snbeams V = 0.0; 2667d8d0e25Snbeams if (tidy < Q1D) 2677d8d0e25Snbeams for (int i = 0; i < P1D; ++i) 2687d8d0e25Snbeams V += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction 2697d8d0e25Snbeams __syncthreads(); 2707d8d0e25Snbeams } 2717d8d0e25Snbeams 2727d8d0e25Snbeams //------------------------------------------------------------------------------ 2737d8d0e25Snbeams // 2D transpose tensor contraction y 2747d8d0e25Snbeams //------------------------------------------------------------------------------ 2757d8d0e25Snbeams inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx, 2767d8d0e25Snbeams const int tidy, const int tidz, 2777d8d0e25Snbeams const CeedScalar &U, const CeedScalar *B, CeedScalar &V) { 2787d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U; 2797d8d0e25Snbeams __syncthreads(); 2807d8d0e25Snbeams V = 0.0; 2817d8d0e25Snbeams if (tidy < P1D) 2827d8d0e25Snbeams for (int i = 0; i < Q1D; ++i) 2837d8d0e25Snbeams V += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction 2847d8d0e25Snbeams __syncthreads(); 2857d8d0e25Snbeams } 2867d8d0e25Snbeams 2877d8d0e25Snbeams //------------------------------------------------------------------------------ 2887d8d0e25Snbeams // 2D transpose tensor contraction x 2897d8d0e25Snbeams //------------------------------------------------------------------------------ 2907d8d0e25Snbeams inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx, 2917d8d0e25Snbeams const int tidy, const int tidz, 
2927d8d0e25Snbeams const CeedScalar &U, const CeedScalar *B, CeedScalar &V) { 2937d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U; 2947d8d0e25Snbeams __syncthreads(); 2957d8d0e25Snbeams V = 0.0; 2967d8d0e25Snbeams if (tidx < P1D) 2977d8d0e25Snbeams for (int i = 0; i < Q1D; ++i) 2987d8d0e25Snbeams V += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction 2997d8d0e25Snbeams __syncthreads(); 3007d8d0e25Snbeams } 3017d8d0e25Snbeams 3027d8d0e25Snbeams //------------------------------------------------------------------------------ 3037d8d0e25Snbeams // 2D interpolate to quadrature points 3047d8d0e25Snbeams //------------------------------------------------------------------------------ 3057d8d0e25Snbeams inline __device__ void interp2d(const CeedInt nelem, const int transpose, 3067d8d0e25Snbeams const CeedScalar *c_B, 3077d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 3087d8d0e25Snbeams CeedScalar *__restrict__ d_V, 3097d8d0e25Snbeams CeedScalar *slice) { 3107d8d0e25Snbeams CeedScalar r_V; 3117d8d0e25Snbeams CeedScalar r_t; 3127d8d0e25Snbeams 3137d8d0e25Snbeams const int tidx = threadIdx.x; 3147d8d0e25Snbeams const int tidy = threadIdx.y; 3157d8d0e25Snbeams const int tidz = threadIdx.z; 3167d8d0e25Snbeams const int blockElem = tidz/BASIS_NCOMP; 3177d8d0e25Snbeams const int elemsPerBlock = blockDim.z/BASIS_NCOMP; 3187d8d0e25Snbeams const int comp = tidz%BASIS_NCOMP; 3197d8d0e25Snbeams 3207d8d0e25Snbeams for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem; 3217d8d0e25Snbeams elem += gridDim.x*elemsPerBlock) { 3227d8d0e25Snbeams const int comp = tidz%BASIS_NCOMP; 3237d8d0e25Snbeams r_V = 0.0; 3247d8d0e25Snbeams r_t = 0.0; 3257d8d0e25Snbeams if (!transpose) { 3267d8d0e25Snbeams readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V); 3277d8d0e25Snbeams ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t); 3287d8d0e25Snbeams ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V); 3297d8d0e25Snbeams writeQuads2d(elem, tidx, 
tidy, comp, 0, nelem, r_V, d_V); 3307d8d0e25Snbeams } else { 3317d8d0e25Snbeams readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V); 3327d8d0e25Snbeams ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t); 3337d8d0e25Snbeams ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V); 3347d8d0e25Snbeams writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V); 3357d8d0e25Snbeams } 3367d8d0e25Snbeams } 3377d8d0e25Snbeams } 3387d8d0e25Snbeams 3397d8d0e25Snbeams //------------------------------------------------------------------------------ 3407d8d0e25Snbeams // 2D derivatives at quadrature points 3417d8d0e25Snbeams //------------------------------------------------------------------------------ 3427d8d0e25Snbeams inline __device__ void grad2d(const CeedInt nelem, const int transpose, 3437d8d0e25Snbeams const CeedScalar *c_B, const CeedScalar *c_G, 3447d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 3457d8d0e25Snbeams CeedScalar *__restrict__ d_V, CeedScalar *slice) { 3467d8d0e25Snbeams CeedScalar r_U; 3477d8d0e25Snbeams CeedScalar r_V; 3487d8d0e25Snbeams CeedScalar r_t; 3497d8d0e25Snbeams 3507d8d0e25Snbeams const int tidx = threadIdx.x; 3517d8d0e25Snbeams const int tidy = threadIdx.y; 3527d8d0e25Snbeams const int tidz = threadIdx.z; 3537d8d0e25Snbeams const int blockElem = tidz/BASIS_NCOMP; 3547d8d0e25Snbeams const int elemsPerBlock = blockDim.z/BASIS_NCOMP; 3557d8d0e25Snbeams const int comp = tidz%BASIS_NCOMP; 3567d8d0e25Snbeams int dim; 3577d8d0e25Snbeams 3587d8d0e25Snbeams for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem; 3597d8d0e25Snbeams elem += gridDim.x*elemsPerBlock) { 3607d8d0e25Snbeams if (!transpose) { 3617d8d0e25Snbeams readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U); 3627d8d0e25Snbeams ContractX2d(slice, tidx, tidy, tidz, r_U, c_G, r_t); 3637d8d0e25Snbeams ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V); 3647d8d0e25Snbeams dim = 0; 3657d8d0e25Snbeams writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V); 
3667d8d0e25Snbeams ContractX2d(slice, tidx, tidy, tidz, r_U, c_B, r_t); 3677d8d0e25Snbeams ContractY2d(slice, tidx, tidy, tidz, r_t, c_G, r_V); 3687d8d0e25Snbeams dim = 1; 3697d8d0e25Snbeams writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V); 3707d8d0e25Snbeams } else { 3717d8d0e25Snbeams dim = 0; 3727d8d0e25Snbeams readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U); 3737d8d0e25Snbeams ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_B, r_t); 3747d8d0e25Snbeams ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_G, r_V); 3757d8d0e25Snbeams dim = 1; 3767d8d0e25Snbeams readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U); 3777d8d0e25Snbeams ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_G, r_t); 3787d8d0e25Snbeams ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_U); 3797d8d0e25Snbeams r_V += r_U; 3807d8d0e25Snbeams writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V); 3817d8d0e25Snbeams } 3827d8d0e25Snbeams } 3837d8d0e25Snbeams } 3847d8d0e25Snbeams 3857d8d0e25Snbeams //------------------------------------------------------------------------------ 3867d8d0e25Snbeams // 2D quadrature weights 3877d8d0e25Snbeams //------------------------------------------------------------------------------ 3887d8d0e25Snbeams __device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d, 3897d8d0e25Snbeams CeedScalar *w) { 3907d8d0e25Snbeams const int i = threadIdx.x; 3917d8d0e25Snbeams const int j = threadIdx.y; 3927d8d0e25Snbeams const CeedScalar weight = qweight1d[i]*qweight1d[j]; 3937d8d0e25Snbeams for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem; 3947d8d0e25Snbeams elem += gridDim.x*blockDim.z) { 3957d8d0e25Snbeams const int ind = elem*Q1D*Q1D + i + j*Q1D; 3967d8d0e25Snbeams w[ind] = weight; 3977d8d0e25Snbeams } 3987d8d0e25Snbeams } 3997d8d0e25Snbeams 4007d8d0e25Snbeams //------------------------------------------------------------------------------ 4017d8d0e25Snbeams // 3D 4027d8d0e25Snbeams 
//------------------------------------------------------------------------------ 4037d8d0e25Snbeams 4047d8d0e25Snbeams //------------------------------------------------------------------------------ 4057d8d0e25Snbeams // Read DoFs 4067d8d0e25Snbeams //------------------------------------------------------------------------------ 4077d8d0e25Snbeams inline __device__ void readDofs3d(const int elem, const int tidx, 4087d8d0e25Snbeams const int tidy, const int comp, 4097d8d0e25Snbeams const int nelem, const CeedScalar *d_U, 4107d8d0e25Snbeams CeedScalar *r_U) { 4117d8d0e25Snbeams for (int i = 0; i < P1D; i++) 4127d8d0e25Snbeams r_U[i] = (tidx < P1D && tidy < P1D) ? 4137d8d0e25Snbeams d_U[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D + 4147d8d0e25Snbeams comp*P1D*P1D*P1D*nelem] : 0.0; 4157d8d0e25Snbeams for (int i = P1D; i < Q1D; i++) 4167d8d0e25Snbeams r_U[i] = 0.0; 4177d8d0e25Snbeams } 4187d8d0e25Snbeams 4197d8d0e25Snbeams //------------------------------------------------------------------------------ 4207d8d0e25Snbeams // Write DoFs 4217d8d0e25Snbeams //------------------------------------------------------------------------------ 4227d8d0e25Snbeams inline __device__ void writeDofs3d(const int elem, const int tidx, 4237d8d0e25Snbeams const int tidy, const int comp, 4247d8d0e25Snbeams const int nelem, const CeedScalar *r_V, 4257d8d0e25Snbeams CeedScalar *d_V) { 4267d8d0e25Snbeams if (tidx < P1D && tidy < P1D) { 4277d8d0e25Snbeams for (int i = 0; i < P1D; i++) 4287d8d0e25Snbeams d_V[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D + 4297d8d0e25Snbeams comp*P1D*P1D*P1D*nelem] = r_V[i]; 4307d8d0e25Snbeams } 4317d8d0e25Snbeams } 4327d8d0e25Snbeams 4337d8d0e25Snbeams //------------------------------------------------------------------------------ 4347d8d0e25Snbeams // Read quadrature point data 4357d8d0e25Snbeams //------------------------------------------------------------------------------ 4367d8d0e25Snbeams inline __device__ void readQuads3d(const int elem, const int 
tidx, 4377d8d0e25Snbeams const int tidy, const int comp, 4387d8d0e25Snbeams const int dim, const int nelem, 4397d8d0e25Snbeams const CeedScalar *d_U, CeedScalar *r_U) { 4407d8d0e25Snbeams for (int i = 0; i < Q1D; i++) 4417d8d0e25Snbeams r_U[i] = (tidx < Q1D && tidy < Q1D) ? 4427d8d0e25Snbeams d_U[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D + 4437d8d0e25Snbeams comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] : 0.0; 4447d8d0e25Snbeams for (int i = Q1D; i < P1D; i++) 4457d8d0e25Snbeams r_U[i] = 0.0; 4467d8d0e25Snbeams } 4477d8d0e25Snbeams 4487d8d0e25Snbeams //------------------------------------------------------------------------------ 4497d8d0e25Snbeams // Write quadrature point data 4507d8d0e25Snbeams //------------------------------------------------------------------------------ 4517d8d0e25Snbeams inline __device__ void writeQuads3d(const int elem, const int tidx, 4527d8d0e25Snbeams const int tidy, const int comp, 4537d8d0e25Snbeams const int dim, const int nelem, 4547d8d0e25Snbeams const CeedScalar *r_V, CeedScalar *d_V) { 4557d8d0e25Snbeams if (tidx < Q1D && tidy < Q1D) { 4567d8d0e25Snbeams for (int i = 0; i < Q1D; i++) 4577d8d0e25Snbeams d_V[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem + 4587d8d0e25Snbeams dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] = r_V[i]; 4597d8d0e25Snbeams } 4607d8d0e25Snbeams } 4617d8d0e25Snbeams 4627d8d0e25Snbeams //------------------------------------------------------------------------------ 4637d8d0e25Snbeams // 3D tensor contract x 4647d8d0e25Snbeams //------------------------------------------------------------------------------ 4657d8d0e25Snbeams inline __device__ void ContractX3d(CeedScalar *slice, const int tidx, 4667d8d0e25Snbeams const int tidy, const int tidz, 4677d8d0e25Snbeams const CeedScalar *U, 4687d8d0e25Snbeams const CeedScalar *B, 4697d8d0e25Snbeams CeedScalar *V) { 4707d8d0e25Snbeams for (int k = 0; k < P1D; ++k) { 4717d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k]; 
4727d8d0e25Snbeams __syncthreads(); 4737d8d0e25Snbeams V[k] = 0.0; 4747d8d0e25Snbeams if (tidx < Q1D && tidy < P1D) 4757d8d0e25Snbeams for (int i = 0; i < P1D; ++i) 4767d8d0e25Snbeams V[k] += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction 4777d8d0e25Snbeams __syncthreads(); 4787d8d0e25Snbeams } 4797d8d0e25Snbeams } 4807d8d0e25Snbeams 4817d8d0e25Snbeams //------------------------------------------------------------------------------ 4827d8d0e25Snbeams // 3D tensor contract y 4837d8d0e25Snbeams //------------------------------------------------------------------------------ 4847d8d0e25Snbeams inline __device__ void ContractY3d(CeedScalar *slice, const int tidx, 4857d8d0e25Snbeams const int tidy, const int tidz, 4867d8d0e25Snbeams const CeedScalar *U, 4877d8d0e25Snbeams const CeedScalar *B, 4887d8d0e25Snbeams CeedScalar *V) { 4897d8d0e25Snbeams for (int k = 0; k < P1D; ++k) { 4907d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k]; 4917d8d0e25Snbeams __syncthreads(); 4927d8d0e25Snbeams V[k] = 0.0; 4937d8d0e25Snbeams if (tidx < Q1D && tidy < Q1D) 4947d8d0e25Snbeams for (int i = 0; i < P1D; ++i) 4957d8d0e25Snbeams V[k] += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction 4967d8d0e25Snbeams __syncthreads(); 4977d8d0e25Snbeams } 4987d8d0e25Snbeams } 4997d8d0e25Snbeams 5007d8d0e25Snbeams //------------------------------------------------------------------------------ 5017d8d0e25Snbeams // 3D tensor contract z 5027d8d0e25Snbeams //------------------------------------------------------------------------------ 5037d8d0e25Snbeams inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx, 5047d8d0e25Snbeams const int tidy, const int tidz, 5057d8d0e25Snbeams const CeedScalar *U, 5067d8d0e25Snbeams const CeedScalar *B, 5077d8d0e25Snbeams CeedScalar *V) { 5087d8d0e25Snbeams for (int k = 0; k < Q1D; ++k) { 5097d8d0e25Snbeams V[k] = 0.0; 5107d8d0e25Snbeams if (tidx < Q1D && tidy < Q1D) 5117d8d0e25Snbeams 
for (int i = 0; i < P1D; ++i)
        V[k] += B[i + k*P1D] * U[i]; // Contract z direction
  }
  for (int k = Q1D; k < P1D; ++k)
    V[k] = 0.0;
}

//------------------------------------------------------------------------------
// 3D transpose tensor contract z
//------------------------------------------------------------------------------
// Transposed contraction along z, done entirely on per-thread arrays:
//   V[k] = sum_{i < Q1D} B[k + i*P1D] * U[i]   for k < P1D
// U holds Q1D values for this (tidx, tidy) column; V receives P1D values and
// V[P1D..Q1D) is zeroed. Threads with tidx or tidy >= Q1D produce zeros.
// slice and tidz are not used here; the signature matches the x/y variants.
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U,
    const CeedScalar *B,
    CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    V[k] = 0.0;
    if (tidx < Q1D && tidy < Q1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[k + i*P1D] * U[i]; // Contract z direction
  }
  for (int k = P1D; k < Q1D; ++k)
    V[k] = 0.0;
}

//------------------------------------------------------------------------------
// 3D transpose tensor contract y
//------------------------------------------------------------------------------
// Transposed contraction along y. For each of the P1D z-levels k, every thread
// publishes U[k] into its own entry of the block-shared slice, which holds one
// T1D x T1D plane per tidz. Threads with tidx < Q1D and tidy < P1D then compute
//   V[k] = sum_{i < Q1D} B[tidy + i*P1D] * slice[tidx + i*T1D + plane]
// The bounds check wraps only the accumulation, so every thread reaches both
// __syncthreads() calls. The second barrier keeps the next iteration from
// overwriting slice while other threads are still reading it.
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U,
    const CeedScalar *B,
    CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < Q1D && tidy < P1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
    __syncthreads();
  }
}

//------------------------------------------------------------------------------
// 3D transpose tensor contract x
//------------------------------------------------------------------------------
// Transposed contraction along x, using the same slice exchange as the y
// variant. Threads with tidx < P1D and tidy < P1D compute
//   V[k] = sum_{i < Q1D} B[tidx + i*P1D] * slice[i + tidy*T1D + plane]
// The barriers sit outside the bounds check so every thread reaches them.
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U,
    const CeedScalar *B,
    CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < P1D && tidy < P1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
    __syncthreads();
  }
}

//------------------------------------------------------------------------------
// 3D interpolate to quadrature points
//------------------------------------------------------------------------------
// threadIdx.z encodes (element within block, component):
//   blockElem = tidz / BASIS_NCOMP and comp = tidz % BASIS_NCOMP.
// Each block starts at element blockIdx.x*elemsPerBlock + blockElem and
// advances by gridDim.x*elemsPerBlock.
// Non-transpose: read DoFs, contract x, y, z with c_B, write quadrature values.
// Transpose: read quadrature values, apply the transposed contractions in the
// order z, y, x, write DoFs.
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  // Per-thread work arrays; T1D covers both P1D and Q1D lengths
  CeedScalar r_V[T1D];
  CeedScalar r_t[T1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // Clear both arrays so entries past the active length start at zero
    for (int i = 0; i < T1D; ++i) {
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
    } else {
      readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
    }
  }
}

//------------------------------------------------------------------------------
// 3D derivatives at quadrature points
//------------------------------------------------------------------------------
// Same thread and element mapping as interp3d.
// Non-transpose: for dim = 0, 1, 2 the gradient matrix c_G replaces c_B in the
// x, y or z contraction respectively, and each result is written to quadrature
// slot dim. r_U keeps the element DoFs across all three passes.
// Transpose: each of the three quadrature slots is read and pulled back with
// c_G in the matching direction. The contributions are summed into r_V with
// add() before a single writeDofs3d.
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // Use P1D for one of these
  // NOTE(review): the note above suggests one array could be sized P1D
  // instead of T1D; confirm before changing.
  CeedScalar r_U[T1D];
  CeedScalar r_V[T1D];
  CeedScalar r_t[T1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    for (int i = 0; i < T1D; ++i) {
      r_U[i] = 0.0;
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // d/dx: G along x, B along y and z
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // d/dy: G along y
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_G, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 1;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // d/dz: G along z
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 2;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      // x-derivative contribution lands directly in r_V
      dim = 0;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      // y- and z-derivative contributions go through r_t and are accumulated
      dim = 1;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_G, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      dim = 2;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}

//------------------------------------------------------------------------------
// 3D quadrature weights
//------------------------------------------------------------------------------
// Thread (i, j, k) computes qweight1d[i]*qweight1d[j]*qweight1d[k] once and
// stores it at e*Q1D^3 + i + j*Q1D + k*Q1D^2 for every element e handled by
// this block. Blocks start at element blockIdx.x and advance by gridDim.x.
// Assumes blockDim = (Q1D, Q1D, Q1D), which is how the host launches the dim 3
// weight kernel.
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int i = threadIdx.x;
  const int j = threadIdx.y;
  const int k = threadIdx.z;
  const CeedScalar weight = qweight1d[i]*qweight1d[j]*qweight1d[k];
  for (int e = blockIdx.x; e < nelem; e += gridDim.x) {
    const int ind = e*Q1D*Q1D*Q1D + i + j*Q1D + k*Q1D*Q1D;
    w[ind] = weight;
  }
}
6967d8d0e25Snbeams 6977d8d0e25Snbeams 6987d8d0e25Snbeams //------------------------------------------------------------------------------ 6997d8d0e25Snbeams // Basis kernels 7007d8d0e25Snbeams //------------------------------------------------------------------------------ 7017d8d0e25Snbeams 7027d8d0e25Snbeams //------------------------------------------------------------------------------ 7037d8d0e25Snbeams // Interp kernel by dim 7047d8d0e25Snbeams //------------------------------------------------------------------------------ 705*9e31c45bSnbeams extern "C" __launch_bounds__(INTERP_BLKSIZE) __global__ void interp( 706*9e31c45bSnbeams const CeedInt nelem, const int transpose, 7077d8d0e25Snbeams const CeedScalar *c_B, 7087d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 7097d8d0e25Snbeams CeedScalar *__restrict__ d_V) { 7107d8d0e25Snbeams HIP_DYNAMIC_SHARED( double, slice) 7117d8d0e25Snbeams if (BASIS_DIM == 1) { 7127d8d0e25Snbeams interp1d(nelem, transpose, c_B, d_U, d_V, slice); 7137d8d0e25Snbeams } else if (BASIS_DIM == 2) { 7147d8d0e25Snbeams interp2d(nelem, transpose, c_B, d_U, d_V, slice); 7157d8d0e25Snbeams } else if (BASIS_DIM == 3) { 7167d8d0e25Snbeams interp3d(nelem, transpose, c_B, d_U, d_V, slice); 7177d8d0e25Snbeams } 7187d8d0e25Snbeams } 7197d8d0e25Snbeams 7207d8d0e25Snbeams //------------------------------------------------------------------------------ 7217d8d0e25Snbeams // Grad kernel by dim 7227d8d0e25Snbeams //------------------------------------------------------------------------------ 723*9e31c45bSnbeams extern "C" __launch_bounds__(GRAD_BLKSIZE) __global__ void grad(const CeedInt nelem, 724*9e31c45bSnbeams const int transpose, 7257d8d0e25Snbeams const CeedScalar *c_B, const CeedScalar *c_G, 7267d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 7277d8d0e25Snbeams CeedScalar *__restrict__ d_V) { 7287d8d0e25Snbeams HIP_DYNAMIC_SHARED( double, slice) 7297d8d0e25Snbeams if (BASIS_DIM == 1) { 7307d8d0e25Snbeams grad1d(nelem, transpose, c_B, 
c_G, d_U, d_V, slice); 7317d8d0e25Snbeams } else if (BASIS_DIM == 2) { 7327d8d0e25Snbeams grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice); 7337d8d0e25Snbeams } else if (BASIS_DIM == 3) { 7347d8d0e25Snbeams grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice); 7357d8d0e25Snbeams } 7367d8d0e25Snbeams } 7377d8d0e25Snbeams 7387d8d0e25Snbeams //------------------------------------------------------------------------------ 7397d8d0e25Snbeams // Weight kernels by dim 7407d8d0e25Snbeams //------------------------------------------------------------------------------ 741*9e31c45bSnbeams extern "C" __launch_bounds__(WEIGHT_BLKSIZE) __global__ void weight(const CeedInt nelem, 7427d8d0e25Snbeams const CeedScalar *__restrict__ qweight1d, 7437d8d0e25Snbeams CeedScalar *__restrict__ v) { 7447d8d0e25Snbeams if (BASIS_DIM == 1) { 7457d8d0e25Snbeams weight1d(nelem, qweight1d, v); 7467d8d0e25Snbeams } else if (BASIS_DIM == 2) { 7477d8d0e25Snbeams weight2d(nelem, qweight1d, v); 7487d8d0e25Snbeams } else if (BASIS_DIM == 3) { 7497d8d0e25Snbeams weight3d(nelem, qweight1d, v); 7507d8d0e25Snbeams } 7517d8d0e25Snbeams } 7527d8d0e25Snbeams 7537d8d0e25Snbeams ); 7547d8d0e25Snbeams // *INDENT-ON* 7557d8d0e25Snbeams 7567d8d0e25Snbeams //------------------------------------------------------------------------------ 757*9e31c45bSnbeams // Compute a block size based on required minimum threads 758*9e31c45bSnbeams //------------------------------------------------------------------------------ 759*9e31c45bSnbeams static CeedInt ComputeBlockSizeFromRequirement(const CeedInt required) { 760*9e31c45bSnbeams CeedInt maxSize = 1024; // Max total threads per block 761*9e31c45bSnbeams CeedInt currentSize = 64; // Start with one group 762*9e31c45bSnbeams 763*9e31c45bSnbeams while(currentSize < maxSize) { 764*9e31c45bSnbeams if (currentSize > required) 765*9e31c45bSnbeams break; 766*9e31c45bSnbeams else 767*9e31c45bSnbeams currentSize = currentSize * 2; 768*9e31c45bSnbeams } 769*9e31c45bSnbeams return 
currentSize; 770*9e31c45bSnbeams } 771*9e31c45bSnbeams 772*9e31c45bSnbeams //------------------------------------------------------------------------------ 773*9e31c45bSnbeams // Compute required thread block sizes for basis kernels given P, Q, dim, and 774*9e31c45bSnbeams // ncomp 775*9e31c45bSnbeams //------------------------------------------------------------------------------ 776*9e31c45bSnbeams static int ComputeBasisThreadBlockSizes(const CeedInt dim, const CeedInt P1d, 777*9e31c45bSnbeams const CeedInt Q1d, 778*9e31c45bSnbeams const CeedInt ncomp, CeedInt *blksizes) { 779*9e31c45bSnbeams 780*9e31c45bSnbeams // Note that this will use the same block sizes for all dimensions when compiling, 781*9e31c45bSnbeams // but as each basis object is defined for a particular dimension, we will never 782*9e31c45bSnbeams // call any kernels except the ones for the dimension for which we have computed the 783*9e31c45bSnbeams // block sizes. 784*9e31c45bSnbeams const CeedInt thread1d = CeedIntMax(P1d, Q1d); 785*9e31c45bSnbeams switch (dim) { 786*9e31c45bSnbeams case 1: { 787*9e31c45bSnbeams // Interp kernels: 788*9e31c45bSnbeams blksizes[0] = 256; 789*9e31c45bSnbeams 790*9e31c45bSnbeams // Grad kernels: 791*9e31c45bSnbeams blksizes[1] = 256; 792*9e31c45bSnbeams 793*9e31c45bSnbeams // Weight kernels: 794*9e31c45bSnbeams blksizes[2] = 256; 795*9e31c45bSnbeams 796*9e31c45bSnbeams } break; 797*9e31c45bSnbeams case 2: { 798*9e31c45bSnbeams // Interp kernels: 799*9e31c45bSnbeams CeedInt required = thread1d * thread1d * ncomp; 800*9e31c45bSnbeams blksizes[0] = ComputeBlockSizeFromRequirement(required); 801*9e31c45bSnbeams 802*9e31c45bSnbeams // Grad kernels: currently use same required minimum threads 803*9e31c45bSnbeams blksizes[1] = ComputeBlockSizeFromRequirement(required); 804*9e31c45bSnbeams 805*9e31c45bSnbeams // Weight kernels: 806*9e31c45bSnbeams required = CeedIntMax(64, Q1d * Q1d); 807*9e31c45bSnbeams blksizes[2] = ComputeBlockSizeFromRequirement(required); 
808*9e31c45bSnbeams 809*9e31c45bSnbeams } break; 810*9e31c45bSnbeams case 3: { 811*9e31c45bSnbeams // Interp kernels: 812*9e31c45bSnbeams CeedInt required = thread1d * thread1d * ncomp; 813*9e31c45bSnbeams blksizes[0] = ComputeBlockSizeFromRequirement(required); 814*9e31c45bSnbeams 815*9e31c45bSnbeams // Grad kernels: currently use same required minimum threads 816*9e31c45bSnbeams blksizes[1] = ComputeBlockSizeFromRequirement(required); 817*9e31c45bSnbeams 818*9e31c45bSnbeams // Weight kernels: 819*9e31c45bSnbeams required = Q1d * Q1d * Q1d; 820*9e31c45bSnbeams blksizes[2] = ComputeBlockSizeFromRequirement(required); 821*9e31c45bSnbeams } 822*9e31c45bSnbeams } 823*9e31c45bSnbeams 824*9e31c45bSnbeams return 0; 825*9e31c45bSnbeams } 826*9e31c45bSnbeams 827*9e31c45bSnbeams //------------------------------------------------------------------------------ 8287d8d0e25Snbeams // Device initalization 8297d8d0e25Snbeams //------------------------------------------------------------------------------ 8307d8d0e25Snbeams int CeedHipInitInterp(CeedScalar *d_B, CeedInt P1d, CeedInt Q1d, 8317d8d0e25Snbeams CeedScalar **c_B); 8327d8d0e25Snbeams int CeedHipInitInterpGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P1d, 8337d8d0e25Snbeams CeedInt Q1d, CeedScalar **c_B_ptr, 8347d8d0e25Snbeams CeedScalar **c_G_ptr); 8357d8d0e25Snbeams 8367d8d0e25Snbeams //------------------------------------------------------------------------------ 8377d8d0e25Snbeams // Apply basis 8387d8d0e25Snbeams //------------------------------------------------------------------------------ 8397d8d0e25Snbeams int CeedBasisApplyTensor_Hip_shared(CeedBasis basis, const CeedInt nelem, 8407d8d0e25Snbeams CeedTransposeMode tmode, 8417d8d0e25Snbeams CeedEvalMode emode, CeedVector u, 8427d8d0e25Snbeams CeedVector v) { 8437d8d0e25Snbeams int ierr; 8447d8d0e25Snbeams Ceed ceed; 8457d8d0e25Snbeams ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr); 8467d8d0e25Snbeams Ceed_Hip_shared *ceed_Hip; 8477d8d0e25Snbeams 
CeedGetData(ceed, &ceed_Hip); CeedChk(ierr); 8487d8d0e25Snbeams CeedBasis_Hip_shared *data; 8497d8d0e25Snbeams CeedBasisGetData(basis, &data); CeedChk(ierr); 8507d8d0e25Snbeams const CeedInt transpose = tmode == CEED_TRANSPOSE; 8517d8d0e25Snbeams CeedInt dim, ncomp; 8527d8d0e25Snbeams ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr); 8537d8d0e25Snbeams ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr); 8547d8d0e25Snbeams 8557d8d0e25Snbeams // Read vectors 8567d8d0e25Snbeams const CeedScalar *d_u; 8577d8d0e25Snbeams CeedScalar *d_v; 8587d8d0e25Snbeams if (emode != CEED_EVAL_WEIGHT) { 8597d8d0e25Snbeams ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChk(ierr); 8607d8d0e25Snbeams } 8617d8d0e25Snbeams ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChk(ierr); 8627d8d0e25Snbeams 8637d8d0e25Snbeams // Clear v for transpose mode 8647d8d0e25Snbeams if (tmode == CEED_TRANSPOSE) { 8657d8d0e25Snbeams CeedInt length; 8667d8d0e25Snbeams ierr = CeedVectorGetLength(v, &length); CeedChk(ierr); 8677d8d0e25Snbeams ierr = hipMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChk(ierr); 8687d8d0e25Snbeams } 8697d8d0e25Snbeams 8707d8d0e25Snbeams // Apply basis operation 8717d8d0e25Snbeams switch (emode) { 8727d8d0e25Snbeams case CEED_EVAL_INTERP: { 8737d8d0e25Snbeams CeedInt P1d, Q1d; 874*9e31c45bSnbeams CeedInt blksize = data->blksizes[0]; 8757d8d0e25Snbeams ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr); 8767d8d0e25Snbeams ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr); 8777d8d0e25Snbeams CeedInt thread1d = CeedIntMax(Q1d, P1d); 8787d8d0e25Snbeams ierr = CeedHipInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B); 8797d8d0e25Snbeams CeedChk(ierr); 8807d8d0e25Snbeams void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B, 8817d8d0e25Snbeams &d_u, &d_v 8827d8d0e25Snbeams }; 8837d8d0e25Snbeams if (dim == 1) { 884e7ea6884Snbeams CeedInt elemsPerBlock = 64*thread1d > 256? 
256/thread1d : 64; 8857d8d0e25Snbeams elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1; 8867d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 8877d8d0e25Snbeams ? 1 : 0 ); 8887d8d0e25Snbeams CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar); 8897d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, 1, 8907d8d0e25Snbeams elemsPerBlock, sharedMem, 8917d8d0e25Snbeams interpargs); CeedChk(ierr); 8927d8d0e25Snbeams } else if (dim == 2) { 893*9e31c45bSnbeams // Check if required threads is small enough to do multiple elems 894*9e31c45bSnbeams const CeedInt elemsPerBlock = CeedIntMax(blksize/(thread1d*thread1d*ncomp), 1); 8957d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 8967d8d0e25Snbeams ? 1 : 0 ); 8977d8d0e25Snbeams CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar); 8987d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d, 8997d8d0e25Snbeams ncomp*elemsPerBlock, sharedMem, 9007d8d0e25Snbeams interpargs); CeedChk(ierr); 9017d8d0e25Snbeams } else if (dim == 3) { 9027d8d0e25Snbeams CeedInt elemsPerBlock = 1; 9037d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 9047d8d0e25Snbeams ? 
1 : 0 ); 9057d8d0e25Snbeams CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar); 9067d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d, 9077d8d0e25Snbeams ncomp*elemsPerBlock, sharedMem, 9087d8d0e25Snbeams interpargs); CeedChk(ierr); 9097d8d0e25Snbeams } 9107d8d0e25Snbeams } break; 9117d8d0e25Snbeams case CEED_EVAL_GRAD: { 9127d8d0e25Snbeams CeedInt P1d, Q1d; 913*9e31c45bSnbeams CeedInt blksize = data->blksizes[1]; 9147d8d0e25Snbeams ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr); 9157d8d0e25Snbeams ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr); 9167d8d0e25Snbeams CeedInt thread1d = CeedIntMax(Q1d, P1d); 9177d8d0e25Snbeams ierr = CeedHipInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d, 9187d8d0e25Snbeams Q1d, &data->c_B, &data->c_G); 9197d8d0e25Snbeams CeedChk(ierr); 9207d8d0e25Snbeams void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B, 9217d8d0e25Snbeams &data->c_G, &d_u, &d_v 9227d8d0e25Snbeams }; 9237d8d0e25Snbeams if (dim == 1) { 924e7ea6884Snbeams CeedInt elemsPerBlock = 64*thread1d > 256? 256/thread1d : 64; 9257d8d0e25Snbeams elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1; 9267d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 9277d8d0e25Snbeams ? 1 : 0 ); 9287d8d0e25Snbeams CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar); 9297d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, 1, 9307d8d0e25Snbeams elemsPerBlock, sharedMem, gradargs); 9317d8d0e25Snbeams CeedChk(ierr); 9327d8d0e25Snbeams } else if (dim == 2) { 933*9e31c45bSnbeams // Check if required threads is small enough to do multiple elems 934*9e31c45bSnbeams const CeedInt elemsPerBlock = CeedIntMax(blksize/(thread1d*thread1d*ncomp), 1); 9357d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 9367d8d0e25Snbeams ? 
1 : 0 ); 9377d8d0e25Snbeams CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar); 9387d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d, 9397d8d0e25Snbeams ncomp*elemsPerBlock, sharedMem, 9407d8d0e25Snbeams gradargs); CeedChk(ierr); 9417d8d0e25Snbeams } else if (dim == 3) { 9427d8d0e25Snbeams CeedInt elemsPerBlock = 1; 9437d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 9447d8d0e25Snbeams ? 1 : 0 ); 9457d8d0e25Snbeams CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar); 9467d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d, 9477d8d0e25Snbeams ncomp*elemsPerBlock, sharedMem, 9487d8d0e25Snbeams gradargs); CeedChk(ierr); 9497d8d0e25Snbeams } 9507d8d0e25Snbeams } break; 9517d8d0e25Snbeams case CEED_EVAL_WEIGHT: { 9527d8d0e25Snbeams CeedInt Q1d; 953*9e31c45bSnbeams CeedInt blksize = data->blksizes[2]; 9547d8d0e25Snbeams ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr); 9557d8d0e25Snbeams void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v}; 9567d8d0e25Snbeams if (dim == 1) { 957*9e31c45bSnbeams const CeedInt optElems = blksize/Q1d; 9587d8d0e25Snbeams const CeedInt elemsPerBlock = optElems>0?optElems:1; 9597d8d0e25Snbeams const CeedInt gridsize = nelem/elemsPerBlock + ( ( 9607d8d0e25Snbeams nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 ); 9617d8d0e25Snbeams ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, 9627d8d0e25Snbeams elemsPerBlock, 1, weightargs); 9637d8d0e25Snbeams CeedChk(ierr); 9647d8d0e25Snbeams } else if (dim == 2) { 965*9e31c45bSnbeams const CeedInt optElems = blksize/(Q1d*Q1d); 9667d8d0e25Snbeams const CeedInt elemsPerBlock = optElems>0?optElems:1; 9677d8d0e25Snbeams const CeedInt gridsize = nelem/elemsPerBlock + ( ( 9687d8d0e25Snbeams nelem/elemsPerBlock*elemsPerBlock<nelem)? 
1 : 0 ); 9697d8d0e25Snbeams ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d, 9707d8d0e25Snbeams elemsPerBlock, weightargs); 9717d8d0e25Snbeams CeedChk(ierr); 9727d8d0e25Snbeams } else if (dim == 3) { 9737d8d0e25Snbeams const CeedInt gridsize = nelem; 9747d8d0e25Snbeams ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d, Q1d, 9757d8d0e25Snbeams weightargs); 9767d8d0e25Snbeams CeedChk(ierr); 9777d8d0e25Snbeams } 9787d8d0e25Snbeams } break; 9797d8d0e25Snbeams // LCOV_EXCL_START 9807d8d0e25Snbeams // Evaluate the divergence to/from the quadrature points 9817d8d0e25Snbeams case CEED_EVAL_DIV: 9827d8d0e25Snbeams return CeedError(ceed, 1, "CEED_EVAL_DIV not supported"); 9837d8d0e25Snbeams // Evaluate the curl to/from the quadrature points 9847d8d0e25Snbeams case CEED_EVAL_CURL: 9857d8d0e25Snbeams return CeedError(ceed, 1, "CEED_EVAL_CURL not supported"); 9867d8d0e25Snbeams // Take no action, BasisApply should not have been called 9877d8d0e25Snbeams case CEED_EVAL_NONE: 9887d8d0e25Snbeams return CeedError(ceed, 1, 9897d8d0e25Snbeams "CEED_EVAL_NONE does not make sense in this context"); 9907d8d0e25Snbeams // LCOV_EXCL_STOP 9917d8d0e25Snbeams } 9927d8d0e25Snbeams 9937d8d0e25Snbeams // Restore vectors 9947d8d0e25Snbeams if (emode != CEED_EVAL_WEIGHT) { 9957d8d0e25Snbeams ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChk(ierr); 9967d8d0e25Snbeams } 9977d8d0e25Snbeams ierr = CeedVectorRestoreArray(v, &d_v); CeedChk(ierr); 9987d8d0e25Snbeams return 0; 9997d8d0e25Snbeams } 10007d8d0e25Snbeams 10017d8d0e25Snbeams //------------------------------------------------------------------------------ 10027d8d0e25Snbeams // Destroy basis 10037d8d0e25Snbeams //------------------------------------------------------------------------------ 10047d8d0e25Snbeams static int CeedBasisDestroy_Hip_shared(CeedBasis basis) { 10057d8d0e25Snbeams int ierr; 10067d8d0e25Snbeams Ceed ceed; 10077d8d0e25Snbeams ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr); 
10087d8d0e25Snbeams 10097d8d0e25Snbeams CeedBasis_Hip_shared *data; 10107d8d0e25Snbeams ierr = CeedBasisGetData(basis, &data); CeedChk(ierr); 10117d8d0e25Snbeams 10127d8d0e25Snbeams CeedChk_Hip(ceed, hipModuleUnload(data->module)); 10137d8d0e25Snbeams 10147d8d0e25Snbeams ierr = hipFree(data->d_qweight1d); CeedChk_Hip(ceed, ierr); 10157d8d0e25Snbeams ierr = hipFree(data->d_interp1d); CeedChk_Hip(ceed, ierr); 10167d8d0e25Snbeams ierr = hipFree(data->d_grad1d); CeedChk_Hip(ceed, ierr); 10177d8d0e25Snbeams ierr = hipFree(data->d_collograd1d); CeedChk_Hip(ceed, ierr); 10187d8d0e25Snbeams 10197d8d0e25Snbeams ierr = CeedFree(&data); CeedChk(ierr); 10207d8d0e25Snbeams 10217d8d0e25Snbeams return 0; 10227d8d0e25Snbeams } 10237d8d0e25Snbeams 10247d8d0e25Snbeams //------------------------------------------------------------------------------ 10257d8d0e25Snbeams // Create tensor basis 10267d8d0e25Snbeams //------------------------------------------------------------------------------ 10277d8d0e25Snbeams int CeedBasisCreateTensorH1_Hip_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d, 10287d8d0e25Snbeams const CeedScalar *interp1d, 10297d8d0e25Snbeams const CeedScalar *grad1d, 10307d8d0e25Snbeams const CeedScalar *qref1d, 10317d8d0e25Snbeams const CeedScalar *qweight1d, 10327d8d0e25Snbeams CeedBasis basis) { 10337d8d0e25Snbeams int ierr; 10347d8d0e25Snbeams Ceed ceed; 10357d8d0e25Snbeams ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr); 10367d8d0e25Snbeams CeedBasis_Hip_shared *data; 10377d8d0e25Snbeams ierr = CeedCalloc(1, &data); CeedChk(ierr); 10387d8d0e25Snbeams 10397d8d0e25Snbeams // Copy basis data to GPU 10407d8d0e25Snbeams const CeedInt qBytes = Q1d * sizeof(CeedScalar); 10417d8d0e25Snbeams ierr = hipMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Hip(ceed, ierr); 10427d8d0e25Snbeams ierr = hipMemcpy(data->d_qweight1d, qweight1d, qBytes, 10437d8d0e25Snbeams hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr); 10447d8d0e25Snbeams 10457d8d0e25Snbeams const CeedInt iBytes 
= qBytes * P1d; 10467d8d0e25Snbeams ierr = hipMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Hip(ceed, ierr); 10477d8d0e25Snbeams ierr = hipMemcpy(data->d_interp1d, interp1d, iBytes, 10487d8d0e25Snbeams hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr); 10497d8d0e25Snbeams 10507d8d0e25Snbeams ierr = hipMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Hip(ceed, ierr); 10517d8d0e25Snbeams ierr = hipMemcpy(data->d_grad1d, grad1d, iBytes, 10527d8d0e25Snbeams hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr); 10537d8d0e25Snbeams 10547d8d0e25Snbeams // Compute collocated gradient and copy to GPU 10557d8d0e25Snbeams data->d_collograd1d = NULL; 10567d8d0e25Snbeams if (dim == 3 && Q1d >= P1d) { 10577d8d0e25Snbeams CeedScalar *collograd1d; 10587d8d0e25Snbeams ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChk(ierr); 10597d8d0e25Snbeams ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChk(ierr); 10607d8d0e25Snbeams ierr = hipMalloc((void **)&data->d_collograd1d, qBytes * Q1d); 10617d8d0e25Snbeams CeedChk_Hip(ceed, ierr); 10627d8d0e25Snbeams ierr = hipMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d, 10637d8d0e25Snbeams hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr); 10647d8d0e25Snbeams ierr = CeedFree(&collograd1d); CeedChk(ierr); 10657d8d0e25Snbeams } 10667d8d0e25Snbeams 1067*9e31c45bSnbeams // Set number of threads per block for basis kernels 10687d8d0e25Snbeams CeedInt ncomp; 10697d8d0e25Snbeams ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr); 1070*9e31c45bSnbeams ierr = ComputeBasisThreadBlockSizes(dim, P1d, Q1d, ncomp, data->blksizes); 1071*9e31c45bSnbeams CeedChk(ierr); 1072*9e31c45bSnbeams 1073*9e31c45bSnbeams // Compile basis kernels 1074*9e31c45bSnbeams ierr = CeedCompileHip(ceed, kernelsShared, &data->module, 11, 10757d8d0e25Snbeams "Q1D", Q1d, 10767d8d0e25Snbeams "P1D", P1d, 10777d8d0e25Snbeams "T1D", CeedIntMax(Q1d, P1d), 10787d8d0e25Snbeams "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ? 
10797d8d0e25Snbeams Q1d : P1d, dim), 10807d8d0e25Snbeams "BASIS_DIM", dim, 10817d8d0e25Snbeams "BASIS_NCOMP", ncomp, 10827d8d0e25Snbeams "BASIS_ELEMSIZE", CeedIntPow(P1d, dim), 1083*9e31c45bSnbeams "BASIS_NQPT", CeedIntPow(Q1d, dim), 1084*9e31c45bSnbeams "INTERP_BLKSIZE", data->blksizes[0], 1085*9e31c45bSnbeams "GRAD_BLKSIZE", data->blksizes[1], 1086*9e31c45bSnbeams "WEIGHT_BLKSIZE", data->blksizes[2] 10877d8d0e25Snbeams ); CeedChk(ierr); 10887d8d0e25Snbeams ierr = CeedGetKernelHip(ceed, data->module, "interp", &data->interp); 10897d8d0e25Snbeams CeedChk(ierr); 10907d8d0e25Snbeams ierr = CeedGetKernelHip(ceed, data->module, "grad", &data->grad); 10917d8d0e25Snbeams CeedChk(ierr); 10927d8d0e25Snbeams ierr = CeedGetKernelHip(ceed, data->module, "weight", &data->weight); 10937d8d0e25Snbeams CeedChk(ierr); 10947d8d0e25Snbeams 10957d8d0e25Snbeams ierr = CeedBasisSetData(basis, data); CeedChk(ierr); 10967d8d0e25Snbeams 10977d8d0e25Snbeams // Register backend functions 10987d8d0e25Snbeams ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply", 10997d8d0e25Snbeams CeedBasisApplyTensor_Hip_shared); 11007d8d0e25Snbeams CeedChk(ierr); 11017d8d0e25Snbeams ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", 11027d8d0e25Snbeams CeedBasisDestroy_Hip_shared); CeedChk(ierr); 11037d8d0e25Snbeams return 0; 11047d8d0e25Snbeams } 11057d8d0e25Snbeams //------------------------------------------------------------------------------ 1106