17d8d0e25Snbeams // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC. 27d8d0e25Snbeams // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707. 37d8d0e25Snbeams // All Rights reserved. See files LICENSE and NOTICE for details. 47d8d0e25Snbeams // 57d8d0e25Snbeams // This file is part of CEED, a collection of benchmarks, miniapps, software 67d8d0e25Snbeams // libraries and APIs for efficient high-order finite element and spectral 77d8d0e25Snbeams // element discretizations for exascale applications. For more information and 87d8d0e25Snbeams // source code availability see http://github.com/ceed. 97d8d0e25Snbeams // 107d8d0e25Snbeams // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC, 117d8d0e25Snbeams // a collaborative effort of two U.S. Department of Energy organizations (Office 127d8d0e25Snbeams // of Science and the National Nuclear Security Administration) responsible for 137d8d0e25Snbeams // the planning and preparation of a capable exascale ecosystem, including 147d8d0e25Snbeams // software, applications, hardware, advanced system engineering and early 157d8d0e25Snbeams // testbed platforms, in support of the nation's exascale computing imperative. 
167d8d0e25Snbeams 177d8d0e25Snbeams #include "ceed-hip-shared.h" 187d8d0e25Snbeams #include "../hip/ceed-hip-compile.h" 197d8d0e25Snbeams 207d8d0e25Snbeams //------------------------------------------------------------------------------ 217d8d0e25Snbeams // Shared mem kernels 227d8d0e25Snbeams //------------------------------------------------------------------------------ 237d8d0e25Snbeams // *INDENT-OFF* 247d8d0e25Snbeams static const char *kernelsShared = QUOTE( 257d8d0e25Snbeams 267d8d0e25Snbeams //------------------------------------------------------------------------------ 277d8d0e25Snbeams // Sum input into output 287d8d0e25Snbeams //------------------------------------------------------------------------------ 297d8d0e25Snbeams inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) { 307d8d0e25Snbeams for (int i = 0; i < P1D; i++) 317d8d0e25Snbeams r_V[i] += r_U[i]; 327d8d0e25Snbeams } 337d8d0e25Snbeams 347d8d0e25Snbeams //------------------------------------------------------------------------------ 357d8d0e25Snbeams // 1D 367d8d0e25Snbeams //------------------------------------------------------------------------------ 377d8d0e25Snbeams 387d8d0e25Snbeams //------------------------------------------------------------------------------ 397d8d0e25Snbeams // Read DoFs 407d8d0e25Snbeams //------------------------------------------------------------------------------ 417d8d0e25Snbeams inline __device__ void readDofs1d(const int elem, const int tidx, 427d8d0e25Snbeams const int tidy, const int tidz,const int comp, 437d8d0e25Snbeams const int nelem, const CeedScalar *d_U, 447d8d0e25Snbeams CeedScalar *slice) { 457d8d0e25Snbeams for (int i = 0; i < P1D; i++) 467d8d0e25Snbeams slice[i + tidz*T1D] = d_U[i + elem*P1D + comp*P1D*nelem]; 477d8d0e25Snbeams for (int i = P1D; i < Q1D; i++) 487d8d0e25Snbeams slice[i + tidz*T1D] = 0.0; 497d8d0e25Snbeams } 507d8d0e25Snbeams 517d8d0e25Snbeams 
//------------------------------------------------------------------------------ 527d8d0e25Snbeams // Write DoFs 537d8d0e25Snbeams //------------------------------------------------------------------------------ 547d8d0e25Snbeams inline __device__ void writeDofs1d(const int elem, const int tidx, 557d8d0e25Snbeams const int tidy, const int comp, 567d8d0e25Snbeams const int nelem, const CeedScalar &r_V, 577d8d0e25Snbeams CeedScalar *d_V) { 587d8d0e25Snbeams if (tidx<P1D) 597d8d0e25Snbeams d_V[tidx + elem*P1D + comp*P1D*nelem] = r_V; 607d8d0e25Snbeams } 617d8d0e25Snbeams 627d8d0e25Snbeams //------------------------------------------------------------------------------ 637d8d0e25Snbeams // Read quadrature point data 647d8d0e25Snbeams //------------------------------------------------------------------------------ 657d8d0e25Snbeams inline __device__ void readQuads1d(const int elem, const int tidx, 667d8d0e25Snbeams const int tidy, const int tidz, const int comp, 677d8d0e25Snbeams const int dim, const int nelem, 687d8d0e25Snbeams const CeedScalar *d_U, CeedScalar *slice) { 697d8d0e25Snbeams for (int i = 0; i < Q1D; i++) 707d8d0e25Snbeams slice[i + tidz*T1D] = d_U[i + elem*Q1D + comp*Q1D*nelem + 717d8d0e25Snbeams dim*BASIS_NCOMP*nelem*Q1D]; 727d8d0e25Snbeams for (int i = Q1D; i < P1D; i++) 737d8d0e25Snbeams slice[i + tidz*T1D] = 0.0; 747d8d0e25Snbeams } 757d8d0e25Snbeams 767d8d0e25Snbeams //------------------------------------------------------------------------------ 777d8d0e25Snbeams // Write quadrature point data 787d8d0e25Snbeams //------------------------------------------------------------------------------ 797d8d0e25Snbeams inline __device__ void writeQuads1d(const int elem, const int tidx, 807d8d0e25Snbeams const int tidy, const int comp, 817d8d0e25Snbeams const int dim, const int nelem, 827d8d0e25Snbeams const CeedScalar &r_V, CeedScalar *d_V) { 837d8d0e25Snbeams if (tidx<Q1D) 847d8d0e25Snbeams d_V[tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D] = 
r_V; 857d8d0e25Snbeams } 867d8d0e25Snbeams 877d8d0e25Snbeams //------------------------------------------------------------------------------ 887d8d0e25Snbeams // 1D tensor contraction 897d8d0e25Snbeams //------------------------------------------------------------------------------ 907d8d0e25Snbeams inline __device__ void ContractX1d(CeedScalar *slice, const int tidx, 917d8d0e25Snbeams const int tidy, const int tidz, 927d8d0e25Snbeams const CeedScalar &U, const CeedScalar *B, 937d8d0e25Snbeams CeedScalar &V) { 947d8d0e25Snbeams V = 0.0; 957d8d0e25Snbeams for (int i = 0; i < P1D; ++i) 967d8d0e25Snbeams V += B[i + tidx*P1D] * slice[i + tidz*T1D]; // Contract x direction 977d8d0e25Snbeams } 987d8d0e25Snbeams 997d8d0e25Snbeams //------------------------------------------------------------------------------ 1007d8d0e25Snbeams // 1D transpose tensor contraction 1017d8d0e25Snbeams //------------------------------------------------------------------------------ 1027d8d0e25Snbeams inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx, 1037d8d0e25Snbeams const int tidy, const int tidz, 1047d8d0e25Snbeams const CeedScalar &U, const CeedScalar *B, CeedScalar &V) { 1057d8d0e25Snbeams V = 0.0; 1067d8d0e25Snbeams for (int i = 0; i < Q1D; ++i) 1077d8d0e25Snbeams V += B[tidx + i*P1D] * slice[i + tidz*T1D]; // Contract x direction 1087d8d0e25Snbeams } 1097d8d0e25Snbeams 1107d8d0e25Snbeams //------------------------------------------------------------------------------ 1117d8d0e25Snbeams // 1D interpolate to quadrature points 1127d8d0e25Snbeams //------------------------------------------------------------------------------ 1137d8d0e25Snbeams inline __device__ void interp1d(const CeedInt nelem, const int transpose, 1147d8d0e25Snbeams const CeedScalar *c_B, 1157d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 1167d8d0e25Snbeams CeedScalar *__restrict__ d_V, 1177d8d0e25Snbeams CeedScalar *slice) { 1187d8d0e25Snbeams CeedScalar r_V; 1197d8d0e25Snbeams 
CeedScalar r_t; 1207d8d0e25Snbeams 1217d8d0e25Snbeams const int tidx = threadIdx.x; 1227d8d0e25Snbeams const int tidy = threadIdx.y; 1237d8d0e25Snbeams const int tidz = threadIdx.z; 1247d8d0e25Snbeams 1257d8d0e25Snbeams 1267d8d0e25Snbeams for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem; 1277d8d0e25Snbeams elem += gridDim.x*blockDim.z) { 1287d8d0e25Snbeams for (int comp = 0; comp < BASIS_NCOMP; comp++) { 1297d8d0e25Snbeams if (!transpose) { 1307d8d0e25Snbeams readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice); 1317d8d0e25Snbeams ContractX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V); 1327d8d0e25Snbeams writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V); 1337d8d0e25Snbeams } else { 1347d8d0e25Snbeams readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice); 1357d8d0e25Snbeams ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V); 1367d8d0e25Snbeams writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V); 1377d8d0e25Snbeams } 1387d8d0e25Snbeams } 1397d8d0e25Snbeams } 1407d8d0e25Snbeams } 1417d8d0e25Snbeams 1427d8d0e25Snbeams //------------------------------------------------------------------------------ 1437d8d0e25Snbeams // 1D derivatives at quadrature points 1447d8d0e25Snbeams //------------------------------------------------------------------------------ 1457d8d0e25Snbeams inline __device__ void grad1d(const CeedInt nelem, const int transpose, 1467d8d0e25Snbeams const CeedScalar *c_B, const CeedScalar *c_G, 1477d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 1487d8d0e25Snbeams CeedScalar *__restrict__ d_V, 1497d8d0e25Snbeams CeedScalar *slice) { 1507d8d0e25Snbeams CeedScalar r_U; 1517d8d0e25Snbeams CeedScalar r_V; 1527d8d0e25Snbeams 1537d8d0e25Snbeams const int tidx = threadIdx.x; 1547d8d0e25Snbeams const int tidy = threadIdx.y; 1557d8d0e25Snbeams const int tidz = threadIdx.z; 1567d8d0e25Snbeams int dim; 1577d8d0e25Snbeams 1587d8d0e25Snbeams for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem; 
1597d8d0e25Snbeams elem += gridDim.x*blockDim.z) { 1607d8d0e25Snbeams for(int comp = 0; comp < BASIS_NCOMP; comp++) { 1617d8d0e25Snbeams if (!transpose) { 1627d8d0e25Snbeams readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice); 1637d8d0e25Snbeams ContractX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V); 1647d8d0e25Snbeams dim = 0; 1657d8d0e25Snbeams writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V); 1667d8d0e25Snbeams } else { 1677d8d0e25Snbeams dim = 0; 1687d8d0e25Snbeams readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice); 1697d8d0e25Snbeams ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V); 1707d8d0e25Snbeams writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V); 1717d8d0e25Snbeams } 1727d8d0e25Snbeams } 1737d8d0e25Snbeams } 1747d8d0e25Snbeams } 1757d8d0e25Snbeams 1767d8d0e25Snbeams //------------------------------------------------------------------------------ 1777d8d0e25Snbeams // 1D Quadrature weights 1787d8d0e25Snbeams //------------------------------------------------------------------------------ 1797d8d0e25Snbeams __device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d, 1807d8d0e25Snbeams CeedScalar *w) { 1817d8d0e25Snbeams const int tid = threadIdx.x; 1827d8d0e25Snbeams const CeedScalar weight = qweight1d[tid]; 1837d8d0e25Snbeams for (CeedInt elem = blockIdx.x*blockDim.y + threadIdx.y; elem < nelem; 1847d8d0e25Snbeams elem += gridDim.x*blockDim.y) { 1857d8d0e25Snbeams const int ind = elem*Q1D + tid; 1867d8d0e25Snbeams w[ind] = weight; 1877d8d0e25Snbeams } 1887d8d0e25Snbeams } 1897d8d0e25Snbeams 1907d8d0e25Snbeams //------------------------------------------------------------------------------ 1917d8d0e25Snbeams // 2D 1927d8d0e25Snbeams //------------------------------------------------------------------------------ 1937d8d0e25Snbeams 1947d8d0e25Snbeams //------------------------------------------------------------------------------ 1957d8d0e25Snbeams // Read DoFs 1967d8d0e25Snbeams 
//------------------------------------------------------------------------------ 1977d8d0e25Snbeams inline __device__ void readDofs2d(const int elem, const int tidx, 1987d8d0e25Snbeams const int tidy, const int comp, 1997d8d0e25Snbeams const int nelem, const CeedScalar *d_U, 2007d8d0e25Snbeams CeedScalar &U) { 2017d8d0e25Snbeams U = (tidx<P1D && tidy<P1D) ? 2027d8d0e25Snbeams d_U[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] : 0.0; 2037d8d0e25Snbeams } 2047d8d0e25Snbeams 2057d8d0e25Snbeams //------------------------------------------------------------------------------ 2067d8d0e25Snbeams // Write DoFs 2077d8d0e25Snbeams //------------------------------------------------------------------------------ 2087d8d0e25Snbeams inline __device__ void writeDofs2d(const int elem, const int tidx, 2097d8d0e25Snbeams const int tidy, const int comp, 2107d8d0e25Snbeams const int nelem, const CeedScalar &r_V, 2117d8d0e25Snbeams CeedScalar *d_V) { 2127d8d0e25Snbeams if (tidx<P1D && tidy<P1D) 2137d8d0e25Snbeams d_V[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] = r_V; 2147d8d0e25Snbeams } 2157d8d0e25Snbeams 2167d8d0e25Snbeams //------------------------------------------------------------------------------ 2177d8d0e25Snbeams // Read quadrature point data 2187d8d0e25Snbeams //------------------------------------------------------------------------------ 2197d8d0e25Snbeams inline __device__ void readQuads2d(const int elem, const int tidx, 2207d8d0e25Snbeams const int tidy, const int comp, 2217d8d0e25Snbeams const int dim, const int nelem, 2227d8d0e25Snbeams const CeedScalar *d_U, CeedScalar &U ) { 2237d8d0e25Snbeams U = (tidx<Q1D && tidy<Q1D) ? 
2247d8d0e25Snbeams d_U[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem + 2257d8d0e25Snbeams dim*BASIS_NCOMP*nelem*Q1D*Q1D] : 0.0; 2267d8d0e25Snbeams } 2277d8d0e25Snbeams 2287d8d0e25Snbeams //------------------------------------------------------------------------------ 2297d8d0e25Snbeams // Write quadrature point data 2307d8d0e25Snbeams //------------------------------------------------------------------------------ 2317d8d0e25Snbeams inline __device__ void writeQuads2d(const int elem, const int tidx, 2327d8d0e25Snbeams const int tidy, const int comp, 2337d8d0e25Snbeams const int dim, const int nelem, 2347d8d0e25Snbeams const CeedScalar &r_V, CeedScalar *d_V) { 2357d8d0e25Snbeams if (tidx<Q1D && tidy<Q1D) 2367d8d0e25Snbeams d_V[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem + 2377d8d0e25Snbeams dim*BASIS_NCOMP*nelem*Q1D*Q1D] = r_V; 2387d8d0e25Snbeams } 2397d8d0e25Snbeams 2407d8d0e25Snbeams //------------------------------------------------------------------------------ 2417d8d0e25Snbeams // 2D tensor contraction x 2427d8d0e25Snbeams //------------------------------------------------------------------------------ 2437d8d0e25Snbeams inline __device__ void ContractX2d(CeedScalar *slice, const int tidx, 2447d8d0e25Snbeams const int tidy, const int tidz, 2457d8d0e25Snbeams const CeedScalar &U, const CeedScalar *B, 2467d8d0e25Snbeams CeedScalar &V) { 2477d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U; 2487d8d0e25Snbeams __syncthreads(); 2497d8d0e25Snbeams V = 0.0; 2507d8d0e25Snbeams if (tidx < Q1D) 2517d8d0e25Snbeams for (int i = 0; i < P1D; ++i) 2527d8d0e25Snbeams V += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction 2537d8d0e25Snbeams __syncthreads(); 2547d8d0e25Snbeams } 2557d8d0e25Snbeams 2567d8d0e25Snbeams //------------------------------------------------------------------------------ 2577d8d0e25Snbeams // 2D tensor contraction y 2587d8d0e25Snbeams 
//------------------------------------------------------------------------------ 2597d8d0e25Snbeams inline __device__ void ContractY2d(CeedScalar *slice, const int tidx, 2607d8d0e25Snbeams const int tidy, const int tidz, 2617d8d0e25Snbeams const CeedScalar &U, const CeedScalar *B, 2627d8d0e25Snbeams CeedScalar &V) { 2637d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U; 2647d8d0e25Snbeams __syncthreads(); 2657d8d0e25Snbeams V = 0.0; 2667d8d0e25Snbeams if (tidy < Q1D) 2677d8d0e25Snbeams for (int i = 0; i < P1D; ++i) 2687d8d0e25Snbeams V += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction 2697d8d0e25Snbeams __syncthreads(); 2707d8d0e25Snbeams } 2717d8d0e25Snbeams 2727d8d0e25Snbeams //------------------------------------------------------------------------------ 2737d8d0e25Snbeams // 2D transpose tensor contraction y 2747d8d0e25Snbeams //------------------------------------------------------------------------------ 2757d8d0e25Snbeams inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx, 2767d8d0e25Snbeams const int tidy, const int tidz, 2777d8d0e25Snbeams const CeedScalar &U, const CeedScalar *B, CeedScalar &V) { 2787d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U; 2797d8d0e25Snbeams __syncthreads(); 2807d8d0e25Snbeams V = 0.0; 2817d8d0e25Snbeams if (tidy < P1D) 2827d8d0e25Snbeams for (int i = 0; i < Q1D; ++i) 2837d8d0e25Snbeams V += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction 2847d8d0e25Snbeams __syncthreads(); 2857d8d0e25Snbeams } 2867d8d0e25Snbeams 2877d8d0e25Snbeams //------------------------------------------------------------------------------ 2887d8d0e25Snbeams // 2D transpose tensor contraction x 2897d8d0e25Snbeams //------------------------------------------------------------------------------ 2907d8d0e25Snbeams inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx, 2917d8d0e25Snbeams const int tidy, const int tidz, 
2927d8d0e25Snbeams const CeedScalar &U, const CeedScalar *B, CeedScalar &V) { 2937d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U; 2947d8d0e25Snbeams __syncthreads(); 2957d8d0e25Snbeams V = 0.0; 2967d8d0e25Snbeams if (tidx < P1D) 2977d8d0e25Snbeams for (int i = 0; i < Q1D; ++i) 2987d8d0e25Snbeams V += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction 2997d8d0e25Snbeams __syncthreads(); 3007d8d0e25Snbeams } 3017d8d0e25Snbeams 3027d8d0e25Snbeams //------------------------------------------------------------------------------ 3037d8d0e25Snbeams // 2D interpolate to quadrature points 3047d8d0e25Snbeams //------------------------------------------------------------------------------ 3057d8d0e25Snbeams inline __device__ void interp2d(const CeedInt nelem, const int transpose, 3067d8d0e25Snbeams const CeedScalar *c_B, 3077d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 3087d8d0e25Snbeams CeedScalar *__restrict__ d_V, 3097d8d0e25Snbeams CeedScalar *slice) { 3107d8d0e25Snbeams CeedScalar r_V; 3117d8d0e25Snbeams CeedScalar r_t; 3127d8d0e25Snbeams 3137d8d0e25Snbeams const int tidx = threadIdx.x; 3147d8d0e25Snbeams const int tidy = threadIdx.y; 3157d8d0e25Snbeams const int tidz = threadIdx.z; 3167d8d0e25Snbeams const int blockElem = tidz/BASIS_NCOMP; 3177d8d0e25Snbeams const int elemsPerBlock = blockDim.z/BASIS_NCOMP; 3187d8d0e25Snbeams const int comp = tidz%BASIS_NCOMP; 3197d8d0e25Snbeams 3207d8d0e25Snbeams for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem; 3217d8d0e25Snbeams elem += gridDim.x*elemsPerBlock) { 3227d8d0e25Snbeams const int comp = tidz%BASIS_NCOMP; 3237d8d0e25Snbeams r_V = 0.0; 3247d8d0e25Snbeams r_t = 0.0; 3257d8d0e25Snbeams if (!transpose) { 3267d8d0e25Snbeams readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V); 3277d8d0e25Snbeams ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t); 3287d8d0e25Snbeams ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V); 3297d8d0e25Snbeams writeQuads2d(elem, tidx, 
tidy, comp, 0, nelem, r_V, d_V); 3307d8d0e25Snbeams } else { 3317d8d0e25Snbeams readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V); 3327d8d0e25Snbeams ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t); 3337d8d0e25Snbeams ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V); 3347d8d0e25Snbeams writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V); 3357d8d0e25Snbeams } 3367d8d0e25Snbeams } 3377d8d0e25Snbeams } 3387d8d0e25Snbeams 3397d8d0e25Snbeams //------------------------------------------------------------------------------ 3407d8d0e25Snbeams // 2D derivatives at quadrature points 3417d8d0e25Snbeams //------------------------------------------------------------------------------ 3427d8d0e25Snbeams inline __device__ void grad2d(const CeedInt nelem, const int transpose, 3437d8d0e25Snbeams const CeedScalar *c_B, const CeedScalar *c_G, 3447d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 3457d8d0e25Snbeams CeedScalar *__restrict__ d_V, CeedScalar *slice) { 3467d8d0e25Snbeams CeedScalar r_U; 3477d8d0e25Snbeams CeedScalar r_V; 3487d8d0e25Snbeams CeedScalar r_t; 3497d8d0e25Snbeams 3507d8d0e25Snbeams const int tidx = threadIdx.x; 3517d8d0e25Snbeams const int tidy = threadIdx.y; 3527d8d0e25Snbeams const int tidz = threadIdx.z; 3537d8d0e25Snbeams const int blockElem = tidz/BASIS_NCOMP; 3547d8d0e25Snbeams const int elemsPerBlock = blockDim.z/BASIS_NCOMP; 3557d8d0e25Snbeams const int comp = tidz%BASIS_NCOMP; 3567d8d0e25Snbeams int dim; 3577d8d0e25Snbeams 3587d8d0e25Snbeams for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem; 3597d8d0e25Snbeams elem += gridDim.x*elemsPerBlock) { 3607d8d0e25Snbeams if (!transpose) { 3617d8d0e25Snbeams readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U); 3627d8d0e25Snbeams ContractX2d(slice, tidx, tidy, tidz, r_U, c_G, r_t); 3637d8d0e25Snbeams ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V); 3647d8d0e25Snbeams dim = 0; 3657d8d0e25Snbeams writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V); 
3667d8d0e25Snbeams ContractX2d(slice, tidx, tidy, tidz, r_U, c_B, r_t); 3677d8d0e25Snbeams ContractY2d(slice, tidx, tidy, tidz, r_t, c_G, r_V); 3687d8d0e25Snbeams dim = 1; 3697d8d0e25Snbeams writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V); 3707d8d0e25Snbeams } else { 3717d8d0e25Snbeams dim = 0; 3727d8d0e25Snbeams readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U); 3737d8d0e25Snbeams ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_B, r_t); 3747d8d0e25Snbeams ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_G, r_V); 3757d8d0e25Snbeams dim = 1; 3767d8d0e25Snbeams readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U); 3777d8d0e25Snbeams ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_G, r_t); 3787d8d0e25Snbeams ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_U); 3797d8d0e25Snbeams r_V += r_U; 3807d8d0e25Snbeams writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V); 3817d8d0e25Snbeams } 3827d8d0e25Snbeams } 3837d8d0e25Snbeams } 3847d8d0e25Snbeams 3857d8d0e25Snbeams //------------------------------------------------------------------------------ 3867d8d0e25Snbeams // 2D quadrature weights 3877d8d0e25Snbeams //------------------------------------------------------------------------------ 3887d8d0e25Snbeams __device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d, 3897d8d0e25Snbeams CeedScalar *w) { 3907d8d0e25Snbeams const int i = threadIdx.x; 3917d8d0e25Snbeams const int j = threadIdx.y; 3927d8d0e25Snbeams const CeedScalar weight = qweight1d[i]*qweight1d[j]; 3937d8d0e25Snbeams for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem; 3947d8d0e25Snbeams elem += gridDim.x*blockDim.z) { 3957d8d0e25Snbeams const int ind = elem*Q1D*Q1D + i + j*Q1D; 3967d8d0e25Snbeams w[ind] = weight; 3977d8d0e25Snbeams } 3987d8d0e25Snbeams } 3997d8d0e25Snbeams 4007d8d0e25Snbeams //------------------------------------------------------------------------------ 4017d8d0e25Snbeams // 3D 4027d8d0e25Snbeams 
//------------------------------------------------------------------------------ 4037d8d0e25Snbeams 4047d8d0e25Snbeams //------------------------------------------------------------------------------ 4057d8d0e25Snbeams // Read DoFs 4067d8d0e25Snbeams //------------------------------------------------------------------------------ 4077d8d0e25Snbeams inline __device__ void readDofs3d(const int elem, const int tidx, 4087d8d0e25Snbeams const int tidy, const int comp, 4097d8d0e25Snbeams const int nelem, const CeedScalar *d_U, 4107d8d0e25Snbeams CeedScalar *r_U) { 4117d8d0e25Snbeams for (int i = 0; i < P1D; i++) 4127d8d0e25Snbeams r_U[i] = (tidx < P1D && tidy < P1D) ? 4137d8d0e25Snbeams d_U[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D + 4147d8d0e25Snbeams comp*P1D*P1D*P1D*nelem] : 0.0; 4157d8d0e25Snbeams for (int i = P1D; i < Q1D; i++) 4167d8d0e25Snbeams r_U[i] = 0.0; 4177d8d0e25Snbeams } 4187d8d0e25Snbeams 4197d8d0e25Snbeams //------------------------------------------------------------------------------ 4207d8d0e25Snbeams // Write DoFs 4217d8d0e25Snbeams //------------------------------------------------------------------------------ 4227d8d0e25Snbeams inline __device__ void writeDofs3d(const int elem, const int tidx, 4237d8d0e25Snbeams const int tidy, const int comp, 4247d8d0e25Snbeams const int nelem, const CeedScalar *r_V, 4257d8d0e25Snbeams CeedScalar *d_V) { 4267d8d0e25Snbeams if (tidx < P1D && tidy < P1D) { 4277d8d0e25Snbeams for (int i = 0; i < P1D; i++) 4287d8d0e25Snbeams d_V[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D + 4297d8d0e25Snbeams comp*P1D*P1D*P1D*nelem] = r_V[i]; 4307d8d0e25Snbeams } 4317d8d0e25Snbeams } 4327d8d0e25Snbeams 4337d8d0e25Snbeams //------------------------------------------------------------------------------ 4347d8d0e25Snbeams // Read quadrature point data 4357d8d0e25Snbeams //------------------------------------------------------------------------------ 4367d8d0e25Snbeams inline __device__ void readQuads3d(const int elem, const int 
tidx, 4377d8d0e25Snbeams const int tidy, const int comp, 4387d8d0e25Snbeams const int dim, const int nelem, 4397d8d0e25Snbeams const CeedScalar *d_U, CeedScalar *r_U) { 4407d8d0e25Snbeams for (int i = 0; i < Q1D; i++) 4417d8d0e25Snbeams r_U[i] = (tidx < Q1D && tidy < Q1D) ? 4427d8d0e25Snbeams d_U[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D + 4437d8d0e25Snbeams comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] : 0.0; 4447d8d0e25Snbeams for (int i = Q1D; i < P1D; i++) 4457d8d0e25Snbeams r_U[i] = 0.0; 4467d8d0e25Snbeams } 4477d8d0e25Snbeams 4487d8d0e25Snbeams //------------------------------------------------------------------------------ 4497d8d0e25Snbeams // Write quadrature point data 4507d8d0e25Snbeams //------------------------------------------------------------------------------ 4517d8d0e25Snbeams inline __device__ void writeQuads3d(const int elem, const int tidx, 4527d8d0e25Snbeams const int tidy, const int comp, 4537d8d0e25Snbeams const int dim, const int nelem, 4547d8d0e25Snbeams const CeedScalar *r_V, CeedScalar *d_V) { 4557d8d0e25Snbeams if (tidx < Q1D && tidy < Q1D) { 4567d8d0e25Snbeams for (int i = 0; i < Q1D; i++) 4577d8d0e25Snbeams d_V[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem + 4587d8d0e25Snbeams dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] = r_V[i]; 4597d8d0e25Snbeams } 4607d8d0e25Snbeams } 4617d8d0e25Snbeams 4627d8d0e25Snbeams //------------------------------------------------------------------------------ 4637d8d0e25Snbeams // 3D tensor contract x 4647d8d0e25Snbeams //------------------------------------------------------------------------------ 4657d8d0e25Snbeams inline __device__ void ContractX3d(CeedScalar *slice, const int tidx, 4667d8d0e25Snbeams const int tidy, const int tidz, 4677d8d0e25Snbeams const CeedScalar *U, 4687d8d0e25Snbeams const CeedScalar *B, 4697d8d0e25Snbeams CeedScalar *V) { 4707d8d0e25Snbeams for (int k = 0; k < P1D; ++k) { 4717d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k]; 
4727d8d0e25Snbeams __syncthreads(); 4737d8d0e25Snbeams V[k] = 0.0; 4747d8d0e25Snbeams if (tidx < Q1D && tidy < P1D) 4757d8d0e25Snbeams for (int i = 0; i < P1D; ++i) 4767d8d0e25Snbeams V[k] += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction 4777d8d0e25Snbeams __syncthreads(); 4787d8d0e25Snbeams } 4797d8d0e25Snbeams } 4807d8d0e25Snbeams 4817d8d0e25Snbeams //------------------------------------------------------------------------------ 4827d8d0e25Snbeams // 3D tensor contract y 4837d8d0e25Snbeams //------------------------------------------------------------------------------ 4847d8d0e25Snbeams inline __device__ void ContractY3d(CeedScalar *slice, const int tidx, 4857d8d0e25Snbeams const int tidy, const int tidz, 4867d8d0e25Snbeams const CeedScalar *U, 4877d8d0e25Snbeams const CeedScalar *B, 4887d8d0e25Snbeams CeedScalar *V) { 4897d8d0e25Snbeams for (int k = 0; k < P1D; ++k) { 4907d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k]; 4917d8d0e25Snbeams __syncthreads(); 4927d8d0e25Snbeams V[k] = 0.0; 4937d8d0e25Snbeams if (tidx < Q1D && tidy < Q1D) 4947d8d0e25Snbeams for (int i = 0; i < P1D; ++i) 4957d8d0e25Snbeams V[k] += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction 4967d8d0e25Snbeams __syncthreads(); 4977d8d0e25Snbeams } 4987d8d0e25Snbeams } 4997d8d0e25Snbeams 5007d8d0e25Snbeams //------------------------------------------------------------------------------ 5017d8d0e25Snbeams // 3D tensor contract z 5027d8d0e25Snbeams //------------------------------------------------------------------------------ 5037d8d0e25Snbeams inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx, 5047d8d0e25Snbeams const int tidy, const int tidz, 5057d8d0e25Snbeams const CeedScalar *U, 5067d8d0e25Snbeams const CeedScalar *B, 5077d8d0e25Snbeams CeedScalar *V) { 5087d8d0e25Snbeams for (int k = 0; k < Q1D; ++k) { 5097d8d0e25Snbeams V[k] = 0.0; 5107d8d0e25Snbeams if (tidx < Q1D && tidy < Q1D) 5117d8d0e25Snbeams 
for (int i = 0; i < P1D; ++i)
        V[k] += B[i + k*P1D] * U[i]; // Contract z direction
  }
  for (int k = Q1D; k < P1D; ++k)
    V[k] = 0.0;
}

//------------------------------------------------------------------------------
// 3D transpose tensor contract z
//------------------------------------------------------------------------------
// Transpose of the forward z contraction (B is indexed k + i*P1D instead of
// i + k*P1D). For k < P1D:
//   V[k] = sum over i < Q1D of B[k + i*P1D] * U[i]
// Only threads with tidx < Q1D and tidy < Q1D accumulate; other threads get
// V[k] = 0. Entries V[P1D .. Q1D-1] are zeroed.
// The z direction is held in the per-thread arrays U and V, so slice and tidz
// are not referenced and no barrier is required here.
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    V[k] = 0.0;
    if (tidx < Q1D && tidy < Q1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[k + i*P1D] * U[i]; // Contract z direction
  }
  for (int k = P1D; k < Q1D; ++k)
    V[k] = 0.0;
}

//------------------------------------------------------------------------------
// 3D transpose tensor contract y
//------------------------------------------------------------------------------
// For every k < P1D, each thread stages U[k] in slice at
// tidx + tidy*T1D + tidz*T1D*T1D (one T1D x T1D plane per tidz). Then threads
// with tidx < Q1D and tidy < P1D compute
//   V[k] = sum over i < Q1D of B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]
// and all other threads get V[k] = 0.
// The first __syncthreads() makes the staged plane visible to all threads;
// the second keeps slice from being overwritten for the next k while other
// threads are still reading it. Because of these barriers every thread of the
// block must make the same calls to this function.
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < Q1D && tidy < P1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[tidy + i*P1D]
                * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
    __syncthreads();
  }
}

//------------------------------------------------------------------------------
// 3D transpose tensor contract x
//------------------------------------------------------------------------------
// Same staging pattern as ContractTransposeY3d, but the sum runs along x.
// Threads with tidx < P1D and tidy < P1D compute
//   V[k] = sum over i < Q1D of B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]
// and all other threads get V[k] = 0. Contains __syncthreads(), so every
// thread of the block must make the same calls to this function.
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < P1D && tidy < P1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
    __syncthreads();
  }
}

//------------------------------------------------------------------------------
// 3D interpolate to quadrature points
//------------------------------------------------------------------------------
// transpose == 0: DoFs -> quadrature points, applying c_B along x, y, then z.
// transpose != 0: quadrature points -> DoFs, applying the transposed
// contractions along z, y, then x.
// threadIdx.z encodes both the element slot within the block
// (tidz / BASIS_NCOMP) and the component (tidz % BASIS_NCOMP). Each thread
// keeps T1D values in the per-thread arrays r_V and r_t, which the Z
// contractions index directly without shared memory.
// Elements are visited starting at blockIdx.x*elemsPerBlock + blockElem with
// stride gridDim.x*elemsPerBlock.
// NOTE(review): the contractions call __syncthreads(), so all threads of a
// block must run the same number of loop iterations. That holds when a block
// has one element, which is how the host launches the 3D kernels.
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V[T1D];
  CeedScalar r_t[T1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // Reset the per-thread work arrays for this element
    for (int i = 0; i < T1D; ++i) {
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
    } else {
      readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
    }
  }
}

//------------------------------------------------------------------------------
// 3D derivatives at quadrature points
//------------------------------------------------------------------------------
// transpose == 0: for each direction d = 0, 1, 2, applies c_G along direction
// d and c_B along the other two, then writes the result with dim = d. The
// DoF values read once into r_U are reused for all three directions.
// transpose != 0: for each d, reads quadrature input dim = d and applies the
// matching transposed contractions (c_G along d, c_B elsewhere). The three
// results are summed into r_V with add() before a single writeDofs3d.
// Thread/element mapping and barrier requirements are the same as interp3d.
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // Use P1D for one of these
  CeedScalar r_U[T1D];
  CeedScalar r_V[T1D];
  CeedScalar r_t[T1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // Reset the per-thread work arrays for this element
    for (int i = 0; i < T1D; ++i) {
      r_U[i] = 0.0;
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // d/dx: G along x, B along y and z
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // d/dy: G along y
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_G, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 1;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // d/dz: G along z
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 2;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      // Direction 0: result initializes r_V
      dim = 0;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      // Direction 1: accumulated into r_V
      dim = 1;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_G, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      // Direction 2: accumulated into r_V
      dim = 2;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}

//------------------------------------------------------------------------------
// 3D quadrature weights
//------------------------------------------------------------------------------
// Thread (i, j, k) writes the tensor-product weight
// qweight1d[i]*qweight1d[j]*qweight1d[k] to w[e*Q1D^3 + i + j*Q1D + k*Q1D^2]
// for elements e = blockIdx.x, blockIdx.x + gridDim.x, ...
// Expects a Q1D x Q1D x Q1D thread block, which matches the host launch.
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int i = threadIdx.x;
  const int j = threadIdx.y;
  const int k = threadIdx.z;
  const CeedScalar weight = qweight1d[i]*qweight1d[j]*qweight1d[k];
  for (int e = blockIdx.x; e < nelem; e += gridDim.x) {
    const int ind = e*Q1D*Q1D*Q1D + i + j*Q1D + k*Q1D*Q1D;
    w[ind] = weight;
  }
}
6967d8d0e25Snbeams 6977d8d0e25Snbeams 6987d8d0e25Snbeams //------------------------------------------------------------------------------ 6997d8d0e25Snbeams // Basis kernels 7007d8d0e25Snbeams //------------------------------------------------------------------------------ 7017d8d0e25Snbeams 7027d8d0e25Snbeams //------------------------------------------------------------------------------ 7037d8d0e25Snbeams // Interp kernel by dim 7047d8d0e25Snbeams //------------------------------------------------------------------------------ 7057d8d0e25Snbeams extern "C" __global__ void interp(const CeedInt nelem, const int transpose, 7067d8d0e25Snbeams const CeedScalar *c_B, 7077d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 7087d8d0e25Snbeams CeedScalar *__restrict__ d_V) { 7097d8d0e25Snbeams HIP_DYNAMIC_SHARED( double, slice) 7107d8d0e25Snbeams if (BASIS_DIM == 1) { 7117d8d0e25Snbeams interp1d(nelem, transpose, c_B, d_U, d_V, slice); 7127d8d0e25Snbeams } else if (BASIS_DIM == 2) { 7137d8d0e25Snbeams interp2d(nelem, transpose, c_B, d_U, d_V, slice); 7147d8d0e25Snbeams } else if (BASIS_DIM == 3) { 7157d8d0e25Snbeams interp3d(nelem, transpose, c_B, d_U, d_V, slice); 7167d8d0e25Snbeams } 7177d8d0e25Snbeams } 7187d8d0e25Snbeams 7197d8d0e25Snbeams //------------------------------------------------------------------------------ 7207d8d0e25Snbeams // Grad kernel by dim 7217d8d0e25Snbeams //------------------------------------------------------------------------------ 7227d8d0e25Snbeams extern "C" __global__ void grad(const CeedInt nelem, const int transpose, 7237d8d0e25Snbeams const CeedScalar *c_B, const CeedScalar *c_G, 7247d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 7257d8d0e25Snbeams CeedScalar *__restrict__ d_V) { 7267d8d0e25Snbeams HIP_DYNAMIC_SHARED( double, slice) 7277d8d0e25Snbeams if (BASIS_DIM == 1) { 7287d8d0e25Snbeams grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice); 7297d8d0e25Snbeams } else if (BASIS_DIM == 2) { 7307d8d0e25Snbeams grad2d(nelem, 
transpose, c_B, c_G, d_U, d_V, slice); 7317d8d0e25Snbeams } else if (BASIS_DIM == 3) { 7327d8d0e25Snbeams grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice); 7337d8d0e25Snbeams } 7347d8d0e25Snbeams } 7357d8d0e25Snbeams 7367d8d0e25Snbeams //------------------------------------------------------------------------------ 7377d8d0e25Snbeams // Weight kernels by dim 7387d8d0e25Snbeams //------------------------------------------------------------------------------ 7397d8d0e25Snbeams extern "C" __global__ void weight(const CeedInt nelem, 7407d8d0e25Snbeams const CeedScalar *__restrict__ qweight1d, 7417d8d0e25Snbeams CeedScalar *__restrict__ v) { 7427d8d0e25Snbeams if (BASIS_DIM == 1) { 7437d8d0e25Snbeams weight1d(nelem, qweight1d, v); 7447d8d0e25Snbeams } else if (BASIS_DIM == 2) { 7457d8d0e25Snbeams weight2d(nelem, qweight1d, v); 7467d8d0e25Snbeams } else if (BASIS_DIM == 3) { 7477d8d0e25Snbeams weight3d(nelem, qweight1d, v); 7487d8d0e25Snbeams } 7497d8d0e25Snbeams } 7507d8d0e25Snbeams 7517d8d0e25Snbeams ); 7527d8d0e25Snbeams // *INDENT-ON* 7537d8d0e25Snbeams 7547d8d0e25Snbeams //------------------------------------------------------------------------------ 7557d8d0e25Snbeams // Device initalization 7567d8d0e25Snbeams //------------------------------------------------------------------------------ 7577d8d0e25Snbeams int CeedHipInitInterp(CeedScalar *d_B, CeedInt P1d, CeedInt Q1d, 7587d8d0e25Snbeams CeedScalar **c_B); 7597d8d0e25Snbeams int CeedHipInitInterpGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P1d, 7607d8d0e25Snbeams CeedInt Q1d, CeedScalar **c_B_ptr, 7617d8d0e25Snbeams CeedScalar **c_G_ptr); 7627d8d0e25Snbeams 7637d8d0e25Snbeams //------------------------------------------------------------------------------ 7647d8d0e25Snbeams // Apply basis 7657d8d0e25Snbeams //------------------------------------------------------------------------------ 7667d8d0e25Snbeams int CeedBasisApplyTensor_Hip_shared(CeedBasis basis, const CeedInt nelem, 7677d8d0e25Snbeams 
CeedTransposeMode tmode, 7687d8d0e25Snbeams CeedEvalMode emode, CeedVector u, 7697d8d0e25Snbeams CeedVector v) { 7707d8d0e25Snbeams int ierr; 7717d8d0e25Snbeams Ceed ceed; 7727d8d0e25Snbeams ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr); 7737d8d0e25Snbeams Ceed_Hip_shared *ceed_Hip; 7747d8d0e25Snbeams CeedGetData(ceed, &ceed_Hip); CeedChk(ierr); 7757d8d0e25Snbeams CeedBasis_Hip_shared *data; 7767d8d0e25Snbeams CeedBasisGetData(basis, &data); CeedChk(ierr); 7777d8d0e25Snbeams const CeedInt transpose = tmode == CEED_TRANSPOSE; 7787d8d0e25Snbeams CeedInt dim, ncomp; 7797d8d0e25Snbeams ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr); 7807d8d0e25Snbeams ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr); 7817d8d0e25Snbeams 7827d8d0e25Snbeams // Read vectors 7837d8d0e25Snbeams const CeedScalar *d_u; 7847d8d0e25Snbeams CeedScalar *d_v; 7857d8d0e25Snbeams if (emode != CEED_EVAL_WEIGHT) { 7867d8d0e25Snbeams ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChk(ierr); 7877d8d0e25Snbeams } 7887d8d0e25Snbeams ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChk(ierr); 7897d8d0e25Snbeams 7907d8d0e25Snbeams // Clear v for transpose mode 7917d8d0e25Snbeams if (tmode == CEED_TRANSPOSE) { 7927d8d0e25Snbeams CeedInt length; 7937d8d0e25Snbeams ierr = CeedVectorGetLength(v, &length); CeedChk(ierr); 7947d8d0e25Snbeams ierr = hipMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChk(ierr); 7957d8d0e25Snbeams } 7967d8d0e25Snbeams 7977d8d0e25Snbeams // Apply basis operation 7987d8d0e25Snbeams switch (emode) { 7997d8d0e25Snbeams case CEED_EVAL_INTERP: { 8007d8d0e25Snbeams CeedInt P1d, Q1d; 8017d8d0e25Snbeams ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr); 8027d8d0e25Snbeams ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr); 8037d8d0e25Snbeams CeedInt thread1d = CeedIntMax(Q1d, P1d); 8047d8d0e25Snbeams ierr = CeedHipInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B); 8057d8d0e25Snbeams CeedChk(ierr); 
8067d8d0e25Snbeams void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B, 8077d8d0e25Snbeams &d_u, &d_v 8087d8d0e25Snbeams }; 8097d8d0e25Snbeams if (dim == 1) { 810*e7ea6884Snbeams CeedInt elemsPerBlock = 64*thread1d > 256? 256/thread1d : 64; 8117d8d0e25Snbeams elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1; 8127d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 8137d8d0e25Snbeams ? 1 : 0 ); 8147d8d0e25Snbeams CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar); 8157d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, 1, 8167d8d0e25Snbeams elemsPerBlock, sharedMem, 8177d8d0e25Snbeams interpargs); CeedChk(ierr); 8187d8d0e25Snbeams } else if (dim == 2) { 8197d8d0e25Snbeams const CeedInt optElems[7] = {0,32,8,6,4,2,6}; 8207d8d0e25Snbeams // elemsPerBlock must be at least 1 8217d8d0e25Snbeams CeedInt elemsPerBlock = CeedIntMax(thread1d<7?optElems[thread1d]/ncomp:1, 1); 8227d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 8237d8d0e25Snbeams ? 1 : 0 ); 8247d8d0e25Snbeams CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar); 8257d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d, 8267d8d0e25Snbeams ncomp*elemsPerBlock, sharedMem, 8277d8d0e25Snbeams interpargs); CeedChk(ierr); 8287d8d0e25Snbeams } else if (dim == 3) { 8297d8d0e25Snbeams CeedInt elemsPerBlock = 1; 8307d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 8317d8d0e25Snbeams ? 
1 : 0 ); 8327d8d0e25Snbeams CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar); 8337d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d, 8347d8d0e25Snbeams ncomp*elemsPerBlock, sharedMem, 8357d8d0e25Snbeams interpargs); CeedChk(ierr); 8367d8d0e25Snbeams } 8377d8d0e25Snbeams } break; 8387d8d0e25Snbeams case CEED_EVAL_GRAD: { 8397d8d0e25Snbeams CeedInt P1d, Q1d; 8407d8d0e25Snbeams ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr); 8417d8d0e25Snbeams ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr); 8427d8d0e25Snbeams CeedInt thread1d = CeedIntMax(Q1d, P1d); 8437d8d0e25Snbeams ierr = CeedHipInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d, 8447d8d0e25Snbeams Q1d, &data->c_B, &data->c_G); 8457d8d0e25Snbeams CeedChk(ierr); 8467d8d0e25Snbeams void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B, 8477d8d0e25Snbeams &data->c_G, &d_u, &d_v 8487d8d0e25Snbeams }; 8497d8d0e25Snbeams if (dim == 1) { 850*e7ea6884Snbeams CeedInt elemsPerBlock = 64*thread1d > 256? 256/thread1d : 64; 8517d8d0e25Snbeams elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1; 8527d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 8537d8d0e25Snbeams ? 1 : 0 ); 8547d8d0e25Snbeams CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar); 8557d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, 1, 8567d8d0e25Snbeams elemsPerBlock, sharedMem, gradargs); 8577d8d0e25Snbeams CeedChk(ierr); 8587d8d0e25Snbeams } else if (dim == 2) { 8597d8d0e25Snbeams const CeedInt optElems[7] = {0,32,8,6,4,2,6}; 8607d8d0e25Snbeams // elemsPerBlock must be at least 1 8617d8d0e25Snbeams CeedInt elemsPerBlock = CeedIntMax(thread1d<7?optElems[thread1d]/ncomp:1, 1); 8627d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 8637d8d0e25Snbeams ? 
1 : 0 ); 8647d8d0e25Snbeams CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar); 8657d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d, 8667d8d0e25Snbeams ncomp*elemsPerBlock, sharedMem, 8677d8d0e25Snbeams gradargs); CeedChk(ierr); 8687d8d0e25Snbeams } else if (dim == 3) { 8697d8d0e25Snbeams CeedInt elemsPerBlock = 1; 8707d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 8717d8d0e25Snbeams ? 1 : 0 ); 8727d8d0e25Snbeams CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar); 8737d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d, 8747d8d0e25Snbeams ncomp*elemsPerBlock, sharedMem, 8757d8d0e25Snbeams gradargs); CeedChk(ierr); 8767d8d0e25Snbeams } 8777d8d0e25Snbeams } break; 8787d8d0e25Snbeams case CEED_EVAL_WEIGHT: { 8797d8d0e25Snbeams CeedInt Q1d; 8807d8d0e25Snbeams ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr); 8817d8d0e25Snbeams void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v}; 8827d8d0e25Snbeams if (dim == 1) { 883*e7ea6884Snbeams const CeedInt optElems = 64/Q1d; 8847d8d0e25Snbeams const CeedInt elemsPerBlock = optElems>0?optElems:1; 8857d8d0e25Snbeams const CeedInt gridsize = nelem/elemsPerBlock + ( ( 8867d8d0e25Snbeams nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 ); 8877d8d0e25Snbeams ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, 8887d8d0e25Snbeams elemsPerBlock, 1, weightargs); 8897d8d0e25Snbeams CeedChk(ierr); 8907d8d0e25Snbeams } else if (dim == 2) { 891*e7ea6884Snbeams const CeedInt optElems = 64/(Q1d*Q1d); 8927d8d0e25Snbeams const CeedInt elemsPerBlock = optElems>0?optElems:1; 8937d8d0e25Snbeams const CeedInt gridsize = nelem/elemsPerBlock + ( ( 8947d8d0e25Snbeams nelem/elemsPerBlock*elemsPerBlock<nelem)? 
1 : 0 ); 8957d8d0e25Snbeams ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d, 8967d8d0e25Snbeams elemsPerBlock, weightargs); 8977d8d0e25Snbeams CeedChk(ierr); 8987d8d0e25Snbeams } else if (dim == 3) { 8997d8d0e25Snbeams const CeedInt gridsize = nelem; 9007d8d0e25Snbeams ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d, Q1d, 9017d8d0e25Snbeams weightargs); 9027d8d0e25Snbeams CeedChk(ierr); 9037d8d0e25Snbeams } 9047d8d0e25Snbeams } break; 9057d8d0e25Snbeams // LCOV_EXCL_START 9067d8d0e25Snbeams // Evaluate the divergence to/from the quadrature points 9077d8d0e25Snbeams case CEED_EVAL_DIV: 9087d8d0e25Snbeams return CeedError(ceed, 1, "CEED_EVAL_DIV not supported"); 9097d8d0e25Snbeams // Evaluate the curl to/from the quadrature points 9107d8d0e25Snbeams case CEED_EVAL_CURL: 9117d8d0e25Snbeams return CeedError(ceed, 1, "CEED_EVAL_CURL not supported"); 9127d8d0e25Snbeams // Take no action, BasisApply should not have been called 9137d8d0e25Snbeams case CEED_EVAL_NONE: 9147d8d0e25Snbeams return CeedError(ceed, 1, 9157d8d0e25Snbeams "CEED_EVAL_NONE does not make sense in this context"); 9167d8d0e25Snbeams // LCOV_EXCL_STOP 9177d8d0e25Snbeams } 9187d8d0e25Snbeams 9197d8d0e25Snbeams // Restore vectors 9207d8d0e25Snbeams if (emode != CEED_EVAL_WEIGHT) { 9217d8d0e25Snbeams ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChk(ierr); 9227d8d0e25Snbeams } 9237d8d0e25Snbeams ierr = CeedVectorRestoreArray(v, &d_v); CeedChk(ierr); 9247d8d0e25Snbeams return 0; 9257d8d0e25Snbeams } 9267d8d0e25Snbeams 9277d8d0e25Snbeams //------------------------------------------------------------------------------ 9287d8d0e25Snbeams // Destroy basis 9297d8d0e25Snbeams //------------------------------------------------------------------------------ 9307d8d0e25Snbeams static int CeedBasisDestroy_Hip_shared(CeedBasis basis) { 9317d8d0e25Snbeams int ierr; 9327d8d0e25Snbeams Ceed ceed; 9337d8d0e25Snbeams ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr); 
9347d8d0e25Snbeams 9357d8d0e25Snbeams CeedBasis_Hip_shared *data; 9367d8d0e25Snbeams ierr = CeedBasisGetData(basis, &data); CeedChk(ierr); 9377d8d0e25Snbeams 9387d8d0e25Snbeams CeedChk_Hip(ceed, hipModuleUnload(data->module)); 9397d8d0e25Snbeams 9407d8d0e25Snbeams ierr = hipFree(data->d_qweight1d); CeedChk_Hip(ceed, ierr); 9417d8d0e25Snbeams ierr = hipFree(data->d_interp1d); CeedChk_Hip(ceed, ierr); 9427d8d0e25Snbeams ierr = hipFree(data->d_grad1d); CeedChk_Hip(ceed, ierr); 9437d8d0e25Snbeams ierr = hipFree(data->d_collograd1d); CeedChk_Hip(ceed, ierr); 9447d8d0e25Snbeams 9457d8d0e25Snbeams ierr = CeedFree(&data); CeedChk(ierr); 9467d8d0e25Snbeams 9477d8d0e25Snbeams return 0; 9487d8d0e25Snbeams } 9497d8d0e25Snbeams 9507d8d0e25Snbeams //------------------------------------------------------------------------------ 9517d8d0e25Snbeams // Create tensor basis 9527d8d0e25Snbeams //------------------------------------------------------------------------------ 9537d8d0e25Snbeams int CeedBasisCreateTensorH1_Hip_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d, 9547d8d0e25Snbeams const CeedScalar *interp1d, 9557d8d0e25Snbeams const CeedScalar *grad1d, 9567d8d0e25Snbeams const CeedScalar *qref1d, 9577d8d0e25Snbeams const CeedScalar *qweight1d, 9587d8d0e25Snbeams CeedBasis basis) { 9597d8d0e25Snbeams int ierr; 9607d8d0e25Snbeams Ceed ceed; 9617d8d0e25Snbeams ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr); 9627d8d0e25Snbeams CeedBasis_Hip_shared *data; 9637d8d0e25Snbeams ierr = CeedCalloc(1, &data); CeedChk(ierr); 9647d8d0e25Snbeams 9657d8d0e25Snbeams // Copy basis data to GPU 9667d8d0e25Snbeams const CeedInt qBytes = Q1d * sizeof(CeedScalar); 9677d8d0e25Snbeams ierr = hipMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Hip(ceed, ierr); 9687d8d0e25Snbeams ierr = hipMemcpy(data->d_qweight1d, qweight1d, qBytes, 9697d8d0e25Snbeams hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr); 9707d8d0e25Snbeams 9717d8d0e25Snbeams const CeedInt iBytes = qBytes * P1d; 9727d8d0e25Snbeams 
ierr = hipMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Hip(ceed, ierr); 9737d8d0e25Snbeams ierr = hipMemcpy(data->d_interp1d, interp1d, iBytes, 9747d8d0e25Snbeams hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr); 9757d8d0e25Snbeams 9767d8d0e25Snbeams ierr = hipMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Hip(ceed, ierr); 9777d8d0e25Snbeams ierr = hipMemcpy(data->d_grad1d, grad1d, iBytes, 9787d8d0e25Snbeams hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr); 9797d8d0e25Snbeams 9807d8d0e25Snbeams // Compute collocated gradient and copy to GPU 9817d8d0e25Snbeams data->d_collograd1d = NULL; 9827d8d0e25Snbeams if (dim == 3 && Q1d >= P1d) { 9837d8d0e25Snbeams CeedScalar *collograd1d; 9847d8d0e25Snbeams ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChk(ierr); 9857d8d0e25Snbeams ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChk(ierr); 9867d8d0e25Snbeams ierr = hipMalloc((void **)&data->d_collograd1d, qBytes * Q1d); 9877d8d0e25Snbeams CeedChk_Hip(ceed, ierr); 9887d8d0e25Snbeams ierr = hipMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d, 9897d8d0e25Snbeams hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr); 9907d8d0e25Snbeams ierr = CeedFree(&collograd1d); CeedChk(ierr); 9917d8d0e25Snbeams } 9927d8d0e25Snbeams 9937d8d0e25Snbeams // Compile basis kernels 9947d8d0e25Snbeams CeedInt ncomp; 9957d8d0e25Snbeams ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr); 9967d8d0e25Snbeams ierr = CeedCompileHip(ceed, kernelsShared, &data->module, 8, 9977d8d0e25Snbeams "Q1D", Q1d, 9987d8d0e25Snbeams "P1D", P1d, 9997d8d0e25Snbeams "T1D", CeedIntMax(Q1d, P1d), 10007d8d0e25Snbeams "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ? 
10017d8d0e25Snbeams Q1d : P1d, dim), 10027d8d0e25Snbeams "BASIS_DIM", dim, 10037d8d0e25Snbeams "BASIS_NCOMP", ncomp, 10047d8d0e25Snbeams "BASIS_ELEMSIZE", CeedIntPow(P1d, dim), 10057d8d0e25Snbeams "BASIS_NQPT", CeedIntPow(Q1d, dim) 10067d8d0e25Snbeams ); CeedChk(ierr); 10077d8d0e25Snbeams ierr = CeedGetKernelHip(ceed, data->module, "interp", &data->interp); 10087d8d0e25Snbeams CeedChk(ierr); 10097d8d0e25Snbeams ierr = CeedGetKernelHip(ceed, data->module, "grad", &data->grad); 10107d8d0e25Snbeams CeedChk(ierr); 10117d8d0e25Snbeams ierr = CeedGetKernelHip(ceed, data->module, "weight", &data->weight); 10127d8d0e25Snbeams CeedChk(ierr); 10137d8d0e25Snbeams 10147d8d0e25Snbeams ierr = CeedBasisSetData(basis, data); CeedChk(ierr); 10157d8d0e25Snbeams 10167d8d0e25Snbeams // Register backend functions 10177d8d0e25Snbeams ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply", 10187d8d0e25Snbeams CeedBasisApplyTensor_Hip_shared); 10197d8d0e25Snbeams CeedChk(ierr); 10207d8d0e25Snbeams ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", 10217d8d0e25Snbeams CeedBasisDestroy_Hip_shared); CeedChk(ierr); 10227d8d0e25Snbeams return 0; 10237d8d0e25Snbeams } 10247d8d0e25Snbeams //------------------------------------------------------------------------------ 1025