// Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
// Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
// All Rights reserved. See files LICENSE and NOTICE for details.
//
// This file is part of CEED, a collection of benchmarks, miniapps, software
// libraries and APIs for efficient high-order finite element and spectral
// element discretizations for exascale applications. For more information and
// source code availability see http://github.com/ceed.
//
// The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
// a collaborative effort of two U.S. Department of Energy organizations (Office
// of Science and the National Nuclear Security Administration) responsible for
// the planning and preparation of a capable exascale ecosystem, including
// software, applications, hardware, advanced system engineering and early
// testbed platforms, in support of the nation's exascale computing imperative.
167d8d0e25Snbeams 17*3d576824SJeremy L Thompson #include <ceed.h> 18*3d576824SJeremy L Thompson #include <ceed-backend.h> 19*3d576824SJeremy L Thompson #include <hip/hip_runtime.h> 20*3d576824SJeremy L Thompson #include <stddef.h> 217d8d0e25Snbeams #include "ceed-hip-shared.h" 22*3d576824SJeremy L Thompson #include "../hip/ceed-hip.h" 237d8d0e25Snbeams #include "../hip/ceed-hip-compile.h" 247d8d0e25Snbeams 257d8d0e25Snbeams //------------------------------------------------------------------------------ 267d8d0e25Snbeams // Shared mem kernels 277d8d0e25Snbeams //------------------------------------------------------------------------------ 287d8d0e25Snbeams // *INDENT-OFF* 297d8d0e25Snbeams static const char *kernelsShared = QUOTE( 307d8d0e25Snbeams 317d8d0e25Snbeams //------------------------------------------------------------------------------ 327d8d0e25Snbeams // Sum input into output 337d8d0e25Snbeams //------------------------------------------------------------------------------ 347d8d0e25Snbeams inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) { 357d8d0e25Snbeams for (int i = 0; i < P1D; i++) 367d8d0e25Snbeams r_V[i] += r_U[i]; 377d8d0e25Snbeams } 387d8d0e25Snbeams 397d8d0e25Snbeams //------------------------------------------------------------------------------ 407d8d0e25Snbeams // 1D 417d8d0e25Snbeams //------------------------------------------------------------------------------ 427d8d0e25Snbeams 437d8d0e25Snbeams //------------------------------------------------------------------------------ 447d8d0e25Snbeams // Read DoFs 457d8d0e25Snbeams //------------------------------------------------------------------------------ 467d8d0e25Snbeams inline __device__ void readDofs1d(const int elem, const int tidx, 477d8d0e25Snbeams const int tidy, const int tidz,const int comp, 487d8d0e25Snbeams const int nelem, const CeedScalar *d_U, 497d8d0e25Snbeams CeedScalar *slice) { 507d8d0e25Snbeams for (int i = 0; i < P1D; i++) 517d8d0e25Snbeams slice[i 
+ tidz*T1D] = d_U[i + elem*P1D + comp*P1D*nelem]; 527d8d0e25Snbeams for (int i = P1D; i < Q1D; i++) 537d8d0e25Snbeams slice[i + tidz*T1D] = 0.0; 547d8d0e25Snbeams } 557d8d0e25Snbeams 567d8d0e25Snbeams //------------------------------------------------------------------------------ 577d8d0e25Snbeams // Write DoFs 587d8d0e25Snbeams //------------------------------------------------------------------------------ 597d8d0e25Snbeams inline __device__ void writeDofs1d(const int elem, const int tidx, 607d8d0e25Snbeams const int tidy, const int comp, 617d8d0e25Snbeams const int nelem, const CeedScalar &r_V, 627d8d0e25Snbeams CeedScalar *d_V) { 637d8d0e25Snbeams if (tidx<P1D) 647d8d0e25Snbeams d_V[tidx + elem*P1D + comp*P1D*nelem] = r_V; 657d8d0e25Snbeams } 667d8d0e25Snbeams 677d8d0e25Snbeams //------------------------------------------------------------------------------ 687d8d0e25Snbeams // Read quadrature point data 697d8d0e25Snbeams //------------------------------------------------------------------------------ 707d8d0e25Snbeams inline __device__ void readQuads1d(const int elem, const int tidx, 717d8d0e25Snbeams const int tidy, const int tidz, const int comp, 727d8d0e25Snbeams const int dim, const int nelem, 737d8d0e25Snbeams const CeedScalar *d_U, CeedScalar *slice) { 747d8d0e25Snbeams for (int i = 0; i < Q1D; i++) 757d8d0e25Snbeams slice[i + tidz*T1D] = d_U[i + elem*Q1D + comp*Q1D*nelem + 767d8d0e25Snbeams dim*BASIS_NCOMP*nelem*Q1D]; 777d8d0e25Snbeams for (int i = Q1D; i < P1D; i++) 787d8d0e25Snbeams slice[i + tidz*T1D] = 0.0; 797d8d0e25Snbeams } 807d8d0e25Snbeams 817d8d0e25Snbeams //------------------------------------------------------------------------------ 827d8d0e25Snbeams // Write quadrature point data 837d8d0e25Snbeams //------------------------------------------------------------------------------ 847d8d0e25Snbeams inline __device__ void writeQuads1d(const int elem, const int tidx, 857d8d0e25Snbeams const int tidy, const int comp, 867d8d0e25Snbeams const int 
dim, const int nelem, 877d8d0e25Snbeams const CeedScalar &r_V, CeedScalar *d_V) { 887d8d0e25Snbeams if (tidx<Q1D) 897d8d0e25Snbeams d_V[tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D] = r_V; 907d8d0e25Snbeams } 917d8d0e25Snbeams 927d8d0e25Snbeams //------------------------------------------------------------------------------ 937d8d0e25Snbeams // 1D tensor contraction 947d8d0e25Snbeams //------------------------------------------------------------------------------ 957d8d0e25Snbeams inline __device__ void ContractX1d(CeedScalar *slice, const int tidx, 967d8d0e25Snbeams const int tidy, const int tidz, 977d8d0e25Snbeams const CeedScalar &U, const CeedScalar *B, 987d8d0e25Snbeams CeedScalar &V) { 997d8d0e25Snbeams V = 0.0; 1007d8d0e25Snbeams for (int i = 0; i < P1D; ++i) 1017d8d0e25Snbeams V += B[i + tidx*P1D] * slice[i + tidz*T1D]; // Contract x direction 1027d8d0e25Snbeams } 1037d8d0e25Snbeams 1047d8d0e25Snbeams //------------------------------------------------------------------------------ 1057d8d0e25Snbeams // 1D transpose tensor contraction 1067d8d0e25Snbeams //------------------------------------------------------------------------------ 1077d8d0e25Snbeams inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx, 1087d8d0e25Snbeams const int tidy, const int tidz, 1097d8d0e25Snbeams const CeedScalar &U, const CeedScalar *B, CeedScalar &V) { 1107d8d0e25Snbeams V = 0.0; 1117d8d0e25Snbeams for (int i = 0; i < Q1D; ++i) 1127d8d0e25Snbeams V += B[tidx + i*P1D] * slice[i + tidz*T1D]; // Contract x direction 1137d8d0e25Snbeams } 1147d8d0e25Snbeams 1157d8d0e25Snbeams //------------------------------------------------------------------------------ 1167d8d0e25Snbeams // 1D interpolate to quadrature points 1177d8d0e25Snbeams //------------------------------------------------------------------------------ 1187d8d0e25Snbeams inline __device__ void interp1d(const CeedInt nelem, const int transpose, 1197d8d0e25Snbeams const CeedScalar *c_B, 
1207d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 1217d8d0e25Snbeams CeedScalar *__restrict__ d_V, 1227d8d0e25Snbeams CeedScalar *slice) { 1237d8d0e25Snbeams CeedScalar r_V; 1247d8d0e25Snbeams CeedScalar r_t; 1257d8d0e25Snbeams 1267d8d0e25Snbeams const int tidx = threadIdx.x; 1277d8d0e25Snbeams const int tidy = threadIdx.y; 1287d8d0e25Snbeams const int tidz = threadIdx.z; 1297d8d0e25Snbeams 1307d8d0e25Snbeams 1317d8d0e25Snbeams for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem; 1327d8d0e25Snbeams elem += gridDim.x*blockDim.z) { 1337d8d0e25Snbeams for (int comp = 0; comp < BASIS_NCOMP; comp++) { 1347d8d0e25Snbeams if (!transpose) { 1357d8d0e25Snbeams readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice); 1367d8d0e25Snbeams ContractX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V); 1377d8d0e25Snbeams writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V); 1387d8d0e25Snbeams } else { 1397d8d0e25Snbeams readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice); 1407d8d0e25Snbeams ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V); 1417d8d0e25Snbeams writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V); 1427d8d0e25Snbeams } 1437d8d0e25Snbeams } 1447d8d0e25Snbeams } 1457d8d0e25Snbeams } 1467d8d0e25Snbeams 1477d8d0e25Snbeams //------------------------------------------------------------------------------ 1487d8d0e25Snbeams // 1D derivatives at quadrature points 1497d8d0e25Snbeams //------------------------------------------------------------------------------ 1507d8d0e25Snbeams inline __device__ void grad1d(const CeedInt nelem, const int transpose, 1517d8d0e25Snbeams const CeedScalar *c_B, const CeedScalar *c_G, 1527d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 1537d8d0e25Snbeams CeedScalar *__restrict__ d_V, 1547d8d0e25Snbeams CeedScalar *slice) { 1557d8d0e25Snbeams CeedScalar r_U; 1567d8d0e25Snbeams CeedScalar r_V; 1577d8d0e25Snbeams 1587d8d0e25Snbeams const int tidx = threadIdx.x; 1597d8d0e25Snbeams const int tidy = 
threadIdx.y; 1607d8d0e25Snbeams const int tidz = threadIdx.z; 1617d8d0e25Snbeams int dim; 1627d8d0e25Snbeams 1637d8d0e25Snbeams for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem; 1647d8d0e25Snbeams elem += gridDim.x*blockDim.z) { 1657d8d0e25Snbeams for(int comp = 0; comp < BASIS_NCOMP; comp++) { 1667d8d0e25Snbeams if (!transpose) { 1677d8d0e25Snbeams readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice); 1687d8d0e25Snbeams ContractX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V); 1697d8d0e25Snbeams dim = 0; 1707d8d0e25Snbeams writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V); 1717d8d0e25Snbeams } else { 1727d8d0e25Snbeams dim = 0; 1737d8d0e25Snbeams readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice); 1747d8d0e25Snbeams ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V); 1757d8d0e25Snbeams writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V); 1767d8d0e25Snbeams } 1777d8d0e25Snbeams } 1787d8d0e25Snbeams } 1797d8d0e25Snbeams } 1807d8d0e25Snbeams 1817d8d0e25Snbeams //------------------------------------------------------------------------------ 1827d8d0e25Snbeams // 1D Quadrature weights 1837d8d0e25Snbeams //------------------------------------------------------------------------------ 1847d8d0e25Snbeams __device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d, 1857d8d0e25Snbeams CeedScalar *w) { 1867d8d0e25Snbeams const int tid = threadIdx.x; 1877d8d0e25Snbeams const CeedScalar weight = qweight1d[tid]; 1887d8d0e25Snbeams for (CeedInt elem = blockIdx.x*blockDim.y + threadIdx.y; elem < nelem; 1897d8d0e25Snbeams elem += gridDim.x*blockDim.y) { 1907d8d0e25Snbeams const int ind = elem*Q1D + tid; 1917d8d0e25Snbeams w[ind] = weight; 1927d8d0e25Snbeams } 1937d8d0e25Snbeams } 1947d8d0e25Snbeams 1957d8d0e25Snbeams //------------------------------------------------------------------------------ 1967d8d0e25Snbeams // 2D 1977d8d0e25Snbeams 
//------------------------------------------------------------------------------ 1987d8d0e25Snbeams 1997d8d0e25Snbeams //------------------------------------------------------------------------------ 2007d8d0e25Snbeams // Read DoFs 2017d8d0e25Snbeams //------------------------------------------------------------------------------ 2027d8d0e25Snbeams inline __device__ void readDofs2d(const int elem, const int tidx, 2037d8d0e25Snbeams const int tidy, const int comp, 2047d8d0e25Snbeams const int nelem, const CeedScalar *d_U, 2057d8d0e25Snbeams CeedScalar &U) { 2067d8d0e25Snbeams U = (tidx<P1D && tidy<P1D) ? 2077d8d0e25Snbeams d_U[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] : 0.0; 2087d8d0e25Snbeams } 2097d8d0e25Snbeams 2107d8d0e25Snbeams //------------------------------------------------------------------------------ 2117d8d0e25Snbeams // Write DoFs 2127d8d0e25Snbeams //------------------------------------------------------------------------------ 2137d8d0e25Snbeams inline __device__ void writeDofs2d(const int elem, const int tidx, 2147d8d0e25Snbeams const int tidy, const int comp, 2157d8d0e25Snbeams const int nelem, const CeedScalar &r_V, 2167d8d0e25Snbeams CeedScalar *d_V) { 2177d8d0e25Snbeams if (tidx<P1D && tidy<P1D) 2187d8d0e25Snbeams d_V[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] = r_V; 2197d8d0e25Snbeams } 2207d8d0e25Snbeams 2217d8d0e25Snbeams //------------------------------------------------------------------------------ 2227d8d0e25Snbeams // Read quadrature point data 2237d8d0e25Snbeams //------------------------------------------------------------------------------ 2247d8d0e25Snbeams inline __device__ void readQuads2d(const int elem, const int tidx, 2257d8d0e25Snbeams const int tidy, const int comp, 2267d8d0e25Snbeams const int dim, const int nelem, 2277d8d0e25Snbeams const CeedScalar *d_U, CeedScalar &U ) { 2287d8d0e25Snbeams U = (tidx<Q1D && tidy<Q1D) ? 
2297d8d0e25Snbeams d_U[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem + 2307d8d0e25Snbeams dim*BASIS_NCOMP*nelem*Q1D*Q1D] : 0.0; 2317d8d0e25Snbeams } 2327d8d0e25Snbeams 2337d8d0e25Snbeams //------------------------------------------------------------------------------ 2347d8d0e25Snbeams // Write quadrature point data 2357d8d0e25Snbeams //------------------------------------------------------------------------------ 2367d8d0e25Snbeams inline __device__ void writeQuads2d(const int elem, const int tidx, 2377d8d0e25Snbeams const int tidy, const int comp, 2387d8d0e25Snbeams const int dim, const int nelem, 2397d8d0e25Snbeams const CeedScalar &r_V, CeedScalar *d_V) { 2407d8d0e25Snbeams if (tidx<Q1D && tidy<Q1D) 2417d8d0e25Snbeams d_V[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem + 2427d8d0e25Snbeams dim*BASIS_NCOMP*nelem*Q1D*Q1D] = r_V; 2437d8d0e25Snbeams } 2447d8d0e25Snbeams 2457d8d0e25Snbeams //------------------------------------------------------------------------------ 2467d8d0e25Snbeams // 2D tensor contraction x 2477d8d0e25Snbeams //------------------------------------------------------------------------------ 2487d8d0e25Snbeams inline __device__ void ContractX2d(CeedScalar *slice, const int tidx, 2497d8d0e25Snbeams const int tidy, const int tidz, 2507d8d0e25Snbeams const CeedScalar &U, const CeedScalar *B, 2517d8d0e25Snbeams CeedScalar &V) { 2527d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U; 2537d8d0e25Snbeams __syncthreads(); 2547d8d0e25Snbeams V = 0.0; 2557d8d0e25Snbeams if (tidx < Q1D) 2567d8d0e25Snbeams for (int i = 0; i < P1D; ++i) 2577d8d0e25Snbeams V += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction 2587d8d0e25Snbeams __syncthreads(); 2597d8d0e25Snbeams } 2607d8d0e25Snbeams 2617d8d0e25Snbeams //------------------------------------------------------------------------------ 2627d8d0e25Snbeams // 2D tensor contraction y 2637d8d0e25Snbeams 
//------------------------------------------------------------------------------ 2647d8d0e25Snbeams inline __device__ void ContractY2d(CeedScalar *slice, const int tidx, 2657d8d0e25Snbeams const int tidy, const int tidz, 2667d8d0e25Snbeams const CeedScalar &U, const CeedScalar *B, 2677d8d0e25Snbeams CeedScalar &V) { 2687d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U; 2697d8d0e25Snbeams __syncthreads(); 2707d8d0e25Snbeams V = 0.0; 2717d8d0e25Snbeams if (tidy < Q1D) 2727d8d0e25Snbeams for (int i = 0; i < P1D; ++i) 2737d8d0e25Snbeams V += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction 2747d8d0e25Snbeams __syncthreads(); 2757d8d0e25Snbeams } 2767d8d0e25Snbeams 2777d8d0e25Snbeams //------------------------------------------------------------------------------ 2787d8d0e25Snbeams // 2D transpose tensor contraction y 2797d8d0e25Snbeams //------------------------------------------------------------------------------ 2807d8d0e25Snbeams inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx, 2817d8d0e25Snbeams const int tidy, const int tidz, 2827d8d0e25Snbeams const CeedScalar &U, const CeedScalar *B, CeedScalar &V) { 2837d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U; 2847d8d0e25Snbeams __syncthreads(); 2857d8d0e25Snbeams V = 0.0; 2867d8d0e25Snbeams if (tidy < P1D) 2877d8d0e25Snbeams for (int i = 0; i < Q1D; ++i) 2887d8d0e25Snbeams V += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction 2897d8d0e25Snbeams __syncthreads(); 2907d8d0e25Snbeams } 2917d8d0e25Snbeams 2927d8d0e25Snbeams //------------------------------------------------------------------------------ 2937d8d0e25Snbeams // 2D transpose tensor contraction x 2947d8d0e25Snbeams //------------------------------------------------------------------------------ 2957d8d0e25Snbeams inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx, 2967d8d0e25Snbeams const int tidy, const int tidz, 
2977d8d0e25Snbeams const CeedScalar &U, const CeedScalar *B, CeedScalar &V) { 2987d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U; 2997d8d0e25Snbeams __syncthreads(); 3007d8d0e25Snbeams V = 0.0; 3017d8d0e25Snbeams if (tidx < P1D) 3027d8d0e25Snbeams for (int i = 0; i < Q1D; ++i) 3037d8d0e25Snbeams V += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction 3047d8d0e25Snbeams __syncthreads(); 3057d8d0e25Snbeams } 3067d8d0e25Snbeams 3077d8d0e25Snbeams //------------------------------------------------------------------------------ 3087d8d0e25Snbeams // 2D interpolate to quadrature points 3097d8d0e25Snbeams //------------------------------------------------------------------------------ 3107d8d0e25Snbeams inline __device__ void interp2d(const CeedInt nelem, const int transpose, 3117d8d0e25Snbeams const CeedScalar *c_B, 3127d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 3137d8d0e25Snbeams CeedScalar *__restrict__ d_V, 3147d8d0e25Snbeams CeedScalar *slice) { 3157d8d0e25Snbeams CeedScalar r_V; 3167d8d0e25Snbeams CeedScalar r_t; 3177d8d0e25Snbeams 3187d8d0e25Snbeams const int tidx = threadIdx.x; 3197d8d0e25Snbeams const int tidy = threadIdx.y; 3207d8d0e25Snbeams const int tidz = threadIdx.z; 3217d8d0e25Snbeams const int blockElem = tidz/BASIS_NCOMP; 3227d8d0e25Snbeams const int elemsPerBlock = blockDim.z/BASIS_NCOMP; 3237d8d0e25Snbeams const int comp = tidz%BASIS_NCOMP; 3247d8d0e25Snbeams 3257d8d0e25Snbeams for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem; 3267d8d0e25Snbeams elem += gridDim.x*elemsPerBlock) { 3277d8d0e25Snbeams const int comp = tidz%BASIS_NCOMP; 3287d8d0e25Snbeams r_V = 0.0; 3297d8d0e25Snbeams r_t = 0.0; 3307d8d0e25Snbeams if (!transpose) { 3317d8d0e25Snbeams readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V); 3327d8d0e25Snbeams ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t); 3337d8d0e25Snbeams ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V); 3347d8d0e25Snbeams writeQuads2d(elem, tidx, 
tidy, comp, 0, nelem, r_V, d_V); 3357d8d0e25Snbeams } else { 3367d8d0e25Snbeams readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V); 3377d8d0e25Snbeams ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t); 3387d8d0e25Snbeams ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V); 3397d8d0e25Snbeams writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V); 3407d8d0e25Snbeams } 3417d8d0e25Snbeams } 3427d8d0e25Snbeams } 3437d8d0e25Snbeams 3447d8d0e25Snbeams //------------------------------------------------------------------------------ 3457d8d0e25Snbeams // 2D derivatives at quadrature points 3467d8d0e25Snbeams //------------------------------------------------------------------------------ 3477d8d0e25Snbeams inline __device__ void grad2d(const CeedInt nelem, const int transpose, 3487d8d0e25Snbeams const CeedScalar *c_B, const CeedScalar *c_G, 3497d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 3507d8d0e25Snbeams CeedScalar *__restrict__ d_V, CeedScalar *slice) { 3517d8d0e25Snbeams CeedScalar r_U; 3527d8d0e25Snbeams CeedScalar r_V; 3537d8d0e25Snbeams CeedScalar r_t; 3547d8d0e25Snbeams 3557d8d0e25Snbeams const int tidx = threadIdx.x; 3567d8d0e25Snbeams const int tidy = threadIdx.y; 3577d8d0e25Snbeams const int tidz = threadIdx.z; 3587d8d0e25Snbeams const int blockElem = tidz/BASIS_NCOMP; 3597d8d0e25Snbeams const int elemsPerBlock = blockDim.z/BASIS_NCOMP; 3607d8d0e25Snbeams const int comp = tidz%BASIS_NCOMP; 3617d8d0e25Snbeams int dim; 3627d8d0e25Snbeams 3637d8d0e25Snbeams for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem; 3647d8d0e25Snbeams elem += gridDim.x*elemsPerBlock) { 3657d8d0e25Snbeams if (!transpose) { 3667d8d0e25Snbeams readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U); 3677d8d0e25Snbeams ContractX2d(slice, tidx, tidy, tidz, r_U, c_G, r_t); 3687d8d0e25Snbeams ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V); 3697d8d0e25Snbeams dim = 0; 3707d8d0e25Snbeams writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V); 
3717d8d0e25Snbeams ContractX2d(slice, tidx, tidy, tidz, r_U, c_B, r_t); 3727d8d0e25Snbeams ContractY2d(slice, tidx, tidy, tidz, r_t, c_G, r_V); 3737d8d0e25Snbeams dim = 1; 3747d8d0e25Snbeams writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V); 3757d8d0e25Snbeams } else { 3767d8d0e25Snbeams dim = 0; 3777d8d0e25Snbeams readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U); 3787d8d0e25Snbeams ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_B, r_t); 3797d8d0e25Snbeams ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_G, r_V); 3807d8d0e25Snbeams dim = 1; 3817d8d0e25Snbeams readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U); 3827d8d0e25Snbeams ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_G, r_t); 3837d8d0e25Snbeams ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_U); 3847d8d0e25Snbeams r_V += r_U; 3857d8d0e25Snbeams writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V); 3867d8d0e25Snbeams } 3877d8d0e25Snbeams } 3887d8d0e25Snbeams } 3897d8d0e25Snbeams 3907d8d0e25Snbeams //------------------------------------------------------------------------------ 3917d8d0e25Snbeams // 2D quadrature weights 3927d8d0e25Snbeams //------------------------------------------------------------------------------ 3937d8d0e25Snbeams __device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d, 3947d8d0e25Snbeams CeedScalar *w) { 3957d8d0e25Snbeams const int i = threadIdx.x; 3967d8d0e25Snbeams const int j = threadIdx.y; 3977d8d0e25Snbeams const CeedScalar weight = qweight1d[i]*qweight1d[j]; 3987d8d0e25Snbeams for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem; 3997d8d0e25Snbeams elem += gridDim.x*blockDim.z) { 4007d8d0e25Snbeams const int ind = elem*Q1D*Q1D + i + j*Q1D; 4017d8d0e25Snbeams w[ind] = weight; 4027d8d0e25Snbeams } 4037d8d0e25Snbeams } 4047d8d0e25Snbeams 4057d8d0e25Snbeams //------------------------------------------------------------------------------ 4067d8d0e25Snbeams // 3D 4077d8d0e25Snbeams 
//------------------------------------------------------------------------------ 4087d8d0e25Snbeams 4097d8d0e25Snbeams //------------------------------------------------------------------------------ 4107d8d0e25Snbeams // Read DoFs 4117d8d0e25Snbeams //------------------------------------------------------------------------------ 4127d8d0e25Snbeams inline __device__ void readDofs3d(const int elem, const int tidx, 4137d8d0e25Snbeams const int tidy, const int comp, 4147d8d0e25Snbeams const int nelem, const CeedScalar *d_U, 4157d8d0e25Snbeams CeedScalar *r_U) { 4167d8d0e25Snbeams for (int i = 0; i < P1D; i++) 4177d8d0e25Snbeams r_U[i] = (tidx < P1D && tidy < P1D) ? 4187d8d0e25Snbeams d_U[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D + 4197d8d0e25Snbeams comp*P1D*P1D*P1D*nelem] : 0.0; 4207d8d0e25Snbeams for (int i = P1D; i < Q1D; i++) 4217d8d0e25Snbeams r_U[i] = 0.0; 4227d8d0e25Snbeams } 4237d8d0e25Snbeams 4247d8d0e25Snbeams //------------------------------------------------------------------------------ 4257d8d0e25Snbeams // Write DoFs 4267d8d0e25Snbeams //------------------------------------------------------------------------------ 4277d8d0e25Snbeams inline __device__ void writeDofs3d(const int elem, const int tidx, 4287d8d0e25Snbeams const int tidy, const int comp, 4297d8d0e25Snbeams const int nelem, const CeedScalar *r_V, 4307d8d0e25Snbeams CeedScalar *d_V) { 4317d8d0e25Snbeams if (tidx < P1D && tidy < P1D) { 4327d8d0e25Snbeams for (int i = 0; i < P1D; i++) 4337d8d0e25Snbeams d_V[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D + 4347d8d0e25Snbeams comp*P1D*P1D*P1D*nelem] = r_V[i]; 4357d8d0e25Snbeams } 4367d8d0e25Snbeams } 4377d8d0e25Snbeams 4387d8d0e25Snbeams //------------------------------------------------------------------------------ 4397d8d0e25Snbeams // Read quadrature point data 4407d8d0e25Snbeams //------------------------------------------------------------------------------ 4417d8d0e25Snbeams inline __device__ void readQuads3d(const int elem, const int 
tidx, 4427d8d0e25Snbeams const int tidy, const int comp, 4437d8d0e25Snbeams const int dim, const int nelem, 4447d8d0e25Snbeams const CeedScalar *d_U, CeedScalar *r_U) { 4457d8d0e25Snbeams for (int i = 0; i < Q1D; i++) 4467d8d0e25Snbeams r_U[i] = (tidx < Q1D && tidy < Q1D) ? 4477d8d0e25Snbeams d_U[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D + 4487d8d0e25Snbeams comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] : 0.0; 4497d8d0e25Snbeams for (int i = Q1D; i < P1D; i++) 4507d8d0e25Snbeams r_U[i] = 0.0; 4517d8d0e25Snbeams } 4527d8d0e25Snbeams 4537d8d0e25Snbeams //------------------------------------------------------------------------------ 4547d8d0e25Snbeams // Write quadrature point data 4557d8d0e25Snbeams //------------------------------------------------------------------------------ 4567d8d0e25Snbeams inline __device__ void writeQuads3d(const int elem, const int tidx, 4577d8d0e25Snbeams const int tidy, const int comp, 4587d8d0e25Snbeams const int dim, const int nelem, 4597d8d0e25Snbeams const CeedScalar *r_V, CeedScalar *d_V) { 4607d8d0e25Snbeams if (tidx < Q1D && tidy < Q1D) { 4617d8d0e25Snbeams for (int i = 0; i < Q1D; i++) 4627d8d0e25Snbeams d_V[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem + 4637d8d0e25Snbeams dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] = r_V[i]; 4647d8d0e25Snbeams } 4657d8d0e25Snbeams } 4667d8d0e25Snbeams 4677d8d0e25Snbeams //------------------------------------------------------------------------------ 4687d8d0e25Snbeams // 3D tensor contract x 4697d8d0e25Snbeams //------------------------------------------------------------------------------ 4707d8d0e25Snbeams inline __device__ void ContractX3d(CeedScalar *slice, const int tidx, 4717d8d0e25Snbeams const int tidy, const int tidz, 4727d8d0e25Snbeams const CeedScalar *U, 4737d8d0e25Snbeams const CeedScalar *B, 4747d8d0e25Snbeams CeedScalar *V) { 4757d8d0e25Snbeams for (int k = 0; k < P1D; ++k) { 4767d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k]; 
4777d8d0e25Snbeams __syncthreads(); 4787d8d0e25Snbeams V[k] = 0.0; 4797d8d0e25Snbeams if (tidx < Q1D && tidy < P1D) 4807d8d0e25Snbeams for (int i = 0; i < P1D; ++i) 4817d8d0e25Snbeams V[k] += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction 4827d8d0e25Snbeams __syncthreads(); 4837d8d0e25Snbeams } 4847d8d0e25Snbeams } 4857d8d0e25Snbeams 4867d8d0e25Snbeams //------------------------------------------------------------------------------ 4877d8d0e25Snbeams // 3D tensor contract y 4887d8d0e25Snbeams //------------------------------------------------------------------------------ 4897d8d0e25Snbeams inline __device__ void ContractY3d(CeedScalar *slice, const int tidx, 4907d8d0e25Snbeams const int tidy, const int tidz, 4917d8d0e25Snbeams const CeedScalar *U, 4927d8d0e25Snbeams const CeedScalar *B, 4937d8d0e25Snbeams CeedScalar *V) { 4947d8d0e25Snbeams for (int k = 0; k < P1D; ++k) { 4957d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k]; 4967d8d0e25Snbeams __syncthreads(); 4977d8d0e25Snbeams V[k] = 0.0; 4987d8d0e25Snbeams if (tidx < Q1D && tidy < Q1D) 4997d8d0e25Snbeams for (int i = 0; i < P1D; ++i) 5007d8d0e25Snbeams V[k] += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction 5017d8d0e25Snbeams __syncthreads(); 5027d8d0e25Snbeams } 5037d8d0e25Snbeams } 5047d8d0e25Snbeams 5057d8d0e25Snbeams //------------------------------------------------------------------------------ 5067d8d0e25Snbeams // 3D tensor contract z 5077d8d0e25Snbeams //------------------------------------------------------------------------------ 5087d8d0e25Snbeams inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx, 5097d8d0e25Snbeams const int tidy, const int tidz, 5107d8d0e25Snbeams const CeedScalar *U, 5117d8d0e25Snbeams const CeedScalar *B, 5127d8d0e25Snbeams CeedScalar *V) { 5137d8d0e25Snbeams for (int k = 0; k < Q1D; ++k) { 5147d8d0e25Snbeams V[k] = 0.0; 5157d8d0e25Snbeams if (tidx < Q1D && tidy < Q1D) 5167d8d0e25Snbeams 
for (int i = 0; i < P1D; ++i) 5177d8d0e25Snbeams V[k] += B[i + k*P1D] * U[i]; // Contract z direction 5187d8d0e25Snbeams } 5197d8d0e25Snbeams for (int k = Q1D; k < P1D; ++k) 5207d8d0e25Snbeams V[k] = 0.0; 5217d8d0e25Snbeams } 5227d8d0e25Snbeams 5237d8d0e25Snbeams //------------------------------------------------------------------------------ 5247d8d0e25Snbeams // 3D transpose tensor contract z 5257d8d0e25Snbeams //------------------------------------------------------------------------------ 5267d8d0e25Snbeams inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx, 5277d8d0e25Snbeams const int tidy, const int tidz, 5287d8d0e25Snbeams const CeedScalar *U, 5297d8d0e25Snbeams const CeedScalar *B, 5307d8d0e25Snbeams CeedScalar *V) { 5317d8d0e25Snbeams for (int k = 0; k < P1D; ++k) { 5327d8d0e25Snbeams V[k] = 0.0; 5337d8d0e25Snbeams if (tidx < Q1D && tidy < Q1D) 5347d8d0e25Snbeams for (int i = 0; i < Q1D; ++i) 5357d8d0e25Snbeams V[k] += B[k + i*P1D] * U[i]; // Contract z direction 5367d8d0e25Snbeams } 5377d8d0e25Snbeams for (int k = P1D; k < Q1D; ++k) 5387d8d0e25Snbeams V[k] = 0.0; 5397d8d0e25Snbeams } 5407d8d0e25Snbeams 5417d8d0e25Snbeams //------------------------------------------------------------------------------ 5427d8d0e25Snbeams // 3D transpose tensor contract y 5437d8d0e25Snbeams //------------------------------------------------------------------------------ 5447d8d0e25Snbeams inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx, 5457d8d0e25Snbeams const int tidy, const int tidz, 5467d8d0e25Snbeams const CeedScalar *U, 5477d8d0e25Snbeams const CeedScalar *B, 5487d8d0e25Snbeams CeedScalar *V) { 5497d8d0e25Snbeams for (int k = 0; k < P1D; ++k) { 5507d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k]; 5517d8d0e25Snbeams __syncthreads(); 5527d8d0e25Snbeams V[k] = 0.0; 5537d8d0e25Snbeams if (tidx < Q1D && tidy < P1D) 5547d8d0e25Snbeams for (int i = 0; i < Q1D; ++i) 5557d8d0e25Snbeams V[k] += B[tidy + i*P1D] 
* slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction 5567d8d0e25Snbeams __syncthreads(); 5577d8d0e25Snbeams } 5587d8d0e25Snbeams } 5597d8d0e25Snbeams 5607d8d0e25Snbeams //------------------------------------------------------------------------------ 5617d8d0e25Snbeams // 3D transpose tensor contract x 5627d8d0e25Snbeams //------------------------------------------------------------------------------ 5637d8d0e25Snbeams inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx, 5647d8d0e25Snbeams const int tidy, const int tidz, 5657d8d0e25Snbeams const CeedScalar *U, 5667d8d0e25Snbeams const CeedScalar *B, 5677d8d0e25Snbeams CeedScalar *V) { 5687d8d0e25Snbeams for (int k = 0; k < P1D; ++k) { 5697d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k]; 5707d8d0e25Snbeams __syncthreads(); 5717d8d0e25Snbeams V[k] = 0.0; 5727d8d0e25Snbeams if (tidx < P1D && tidy < P1D) 5737d8d0e25Snbeams for (int i = 0; i < Q1D; ++i) 5747d8d0e25Snbeams V[k] += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction 5757d8d0e25Snbeams __syncthreads(); 5767d8d0e25Snbeams } 5777d8d0e25Snbeams } 5787d8d0e25Snbeams 5797d8d0e25Snbeams //------------------------------------------------------------------------------ 5807d8d0e25Snbeams // 3D interpolate to quadrature points 5817d8d0e25Snbeams //------------------------------------------------------------------------------ 5827d8d0e25Snbeams inline __device__ void interp3d(const CeedInt nelem, const int transpose, 5837d8d0e25Snbeams const CeedScalar *c_B, 5847d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 5857d8d0e25Snbeams CeedScalar *__restrict__ d_V, 5867d8d0e25Snbeams CeedScalar *slice) { 5877d8d0e25Snbeams CeedScalar r_V[T1D]; 5887d8d0e25Snbeams CeedScalar r_t[T1D]; 5897d8d0e25Snbeams 5907d8d0e25Snbeams const int tidx = threadIdx.x; 5917d8d0e25Snbeams const int tidy = threadIdx.y; 5927d8d0e25Snbeams const int tidz = threadIdx.z; 5937d8d0e25Snbeams const int blockElem = 
tidz/BASIS_NCOMP; 5947d8d0e25Snbeams const int elemsPerBlock = blockDim.z/BASIS_NCOMP; 5957d8d0e25Snbeams const int comp = tidz%BASIS_NCOMP; 5967d8d0e25Snbeams 5977d8d0e25Snbeams for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem; 5987d8d0e25Snbeams elem += gridDim.x*elemsPerBlock) { 5997d8d0e25Snbeams for (int i = 0; i < T1D; ++i) { 6007d8d0e25Snbeams r_V[i] = 0.0; 6017d8d0e25Snbeams r_t[i] = 0.0; 6027d8d0e25Snbeams } 6037d8d0e25Snbeams if (!transpose) { 6047d8d0e25Snbeams readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V); 6057d8d0e25Snbeams ContractX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t); 6067d8d0e25Snbeams ContractY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V); 6077d8d0e25Snbeams ContractZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t); 6087d8d0e25Snbeams writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V); 6097d8d0e25Snbeams } else { 6107d8d0e25Snbeams readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V); 6117d8d0e25Snbeams ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t); 6127d8d0e25Snbeams ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V); 6137d8d0e25Snbeams ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t); 6147d8d0e25Snbeams writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V); 6157d8d0e25Snbeams } 6167d8d0e25Snbeams } 6177d8d0e25Snbeams } 6187d8d0e25Snbeams 6197d8d0e25Snbeams //------------------------------------------------------------------------------ 6207d8d0e25Snbeams // 3D derivatives at quadrature points 6217d8d0e25Snbeams //------------------------------------------------------------------------------ 6227d8d0e25Snbeams inline __device__ void grad3d(const CeedInt nelem, const int transpose, 6237d8d0e25Snbeams const CeedScalar *c_B, const CeedScalar *c_G, 6247d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 6257d8d0e25Snbeams CeedScalar *__restrict__ d_V, 6267d8d0e25Snbeams CeedScalar *slice) { 6277d8d0e25Snbeams // Use P1D for one of these 6287d8d0e25Snbeams CeedScalar r_U[T1D]; 
6297d8d0e25Snbeams CeedScalar r_V[T1D]; 6307d8d0e25Snbeams CeedScalar r_t[T1D]; 6317d8d0e25Snbeams 6327d8d0e25Snbeams const int tidx = threadIdx.x; 6337d8d0e25Snbeams const int tidy = threadIdx.y; 6347d8d0e25Snbeams const int tidz = threadIdx.z; 6357d8d0e25Snbeams const int blockElem = tidz/BASIS_NCOMP; 6367d8d0e25Snbeams const int elemsPerBlock = blockDim.z/BASIS_NCOMP; 6377d8d0e25Snbeams const int comp = tidz%BASIS_NCOMP; 6387d8d0e25Snbeams int dim; 6397d8d0e25Snbeams 6407d8d0e25Snbeams for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem; 6417d8d0e25Snbeams elem += gridDim.x*elemsPerBlock) { 6427d8d0e25Snbeams for (int i = 0; i < T1D; ++i) { 6437d8d0e25Snbeams r_U[i] = 0.0; 6447d8d0e25Snbeams r_V[i] = 0.0; 6457d8d0e25Snbeams r_t[i] = 0.0; 6467d8d0e25Snbeams } 6477d8d0e25Snbeams if (!transpose) { 6487d8d0e25Snbeams readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U); 6497d8d0e25Snbeams ContractX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V); 6507d8d0e25Snbeams ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t); 6517d8d0e25Snbeams ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V); 6527d8d0e25Snbeams dim = 0; 6537d8d0e25Snbeams writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V); 6547d8d0e25Snbeams ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V); 6557d8d0e25Snbeams ContractY3d(slice, tidx, tidy, tidz, r_V, c_G, r_t); 6567d8d0e25Snbeams ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V); 6577d8d0e25Snbeams dim = 1; 6587d8d0e25Snbeams writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V); 6597d8d0e25Snbeams ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V); 6607d8d0e25Snbeams ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t); 6617d8d0e25Snbeams ContractZ3d(slice, tidx, tidy, tidz, r_t, c_G, r_V); 6627d8d0e25Snbeams dim = 2; 6637d8d0e25Snbeams writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V); 6647d8d0e25Snbeams } else { 6657d8d0e25Snbeams dim = 0; 6667d8d0e25Snbeams readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, 
r_U); 6677d8d0e25Snbeams ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t); 6687d8d0e25Snbeams ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U); 6697d8d0e25Snbeams ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V); 6707d8d0e25Snbeams dim = 1; 6717d8d0e25Snbeams readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U); 6727d8d0e25Snbeams ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t); 6737d8d0e25Snbeams ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_G, r_U); 6747d8d0e25Snbeams ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t); 6757d8d0e25Snbeams add(r_V, r_t); 6767d8d0e25Snbeams dim = 2; 6777d8d0e25Snbeams readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U); 6787d8d0e25Snbeams ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_G, r_t); 6797d8d0e25Snbeams ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U); 6807d8d0e25Snbeams ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t); 6817d8d0e25Snbeams add(r_V, r_t); 6827d8d0e25Snbeams writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V); 6837d8d0e25Snbeams } 6847d8d0e25Snbeams } 6857d8d0e25Snbeams } 6867d8d0e25Snbeams 6877d8d0e25Snbeams //------------------------------------------------------------------------------ 6887d8d0e25Snbeams // 3D quadrature weights 6897d8d0e25Snbeams //------------------------------------------------------------------------------ 6907d8d0e25Snbeams __device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d, 6917d8d0e25Snbeams CeedScalar *w) { 6927d8d0e25Snbeams const int i = threadIdx.x; 6937d8d0e25Snbeams const int j = threadIdx.y; 6947d8d0e25Snbeams const int k = threadIdx.z; 6957d8d0e25Snbeams const CeedScalar weight = qweight1d[i]*qweight1d[j]*qweight1d[k]; 6967d8d0e25Snbeams for (int e = blockIdx.x; e < nelem; e += gridDim.x) { 6977d8d0e25Snbeams const int ind = e*Q1D*Q1D*Q1D + i + j*Q1D + k*Q1D*Q1D; 6987d8d0e25Snbeams w[ind] = weight; 6997d8d0e25Snbeams } 7007d8d0e25Snbeams } 
7017d8d0e25Snbeams 7027d8d0e25Snbeams 7037d8d0e25Snbeams //------------------------------------------------------------------------------ 7047d8d0e25Snbeams // Basis kernels 7057d8d0e25Snbeams //------------------------------------------------------------------------------ 7067d8d0e25Snbeams 7077d8d0e25Snbeams //------------------------------------------------------------------------------ 7087d8d0e25Snbeams // Interp kernel by dim 7097d8d0e25Snbeams //------------------------------------------------------------------------------ 7109e31c45bSnbeams extern "C" __launch_bounds__(INTERP_BLKSIZE) __global__ void interp( 7119e31c45bSnbeams const CeedInt nelem, const int transpose, 7127d8d0e25Snbeams const CeedScalar *c_B, 7137d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 7147d8d0e25Snbeams CeedScalar *__restrict__ d_V) { 7157d8d0e25Snbeams HIP_DYNAMIC_SHARED( double, slice) 7167d8d0e25Snbeams if (BASIS_DIM == 1) { 7177d8d0e25Snbeams interp1d(nelem, transpose, c_B, d_U, d_V, slice); 7187d8d0e25Snbeams } else if (BASIS_DIM == 2) { 7197d8d0e25Snbeams interp2d(nelem, transpose, c_B, d_U, d_V, slice); 7207d8d0e25Snbeams } else if (BASIS_DIM == 3) { 7217d8d0e25Snbeams interp3d(nelem, transpose, c_B, d_U, d_V, slice); 7227d8d0e25Snbeams } 7237d8d0e25Snbeams } 7247d8d0e25Snbeams 7257d8d0e25Snbeams //------------------------------------------------------------------------------ 7267d8d0e25Snbeams // Grad kernel by dim 7277d8d0e25Snbeams //------------------------------------------------------------------------------ 7289e31c45bSnbeams extern "C" __launch_bounds__(GRAD_BLKSIZE) __global__ void grad(const CeedInt nelem, 7299e31c45bSnbeams const int transpose, 7307d8d0e25Snbeams const CeedScalar *c_B, const CeedScalar *c_G, 7317d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 7327d8d0e25Snbeams CeedScalar *__restrict__ d_V) { 7337d8d0e25Snbeams HIP_DYNAMIC_SHARED( double, slice) 7347d8d0e25Snbeams if (BASIS_DIM == 1) { 7357d8d0e25Snbeams grad1d(nelem, transpose, c_B, c_G, 
d_U, d_V, slice); 7367d8d0e25Snbeams } else if (BASIS_DIM == 2) { 7377d8d0e25Snbeams grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice); 7387d8d0e25Snbeams } else if (BASIS_DIM == 3) { 7397d8d0e25Snbeams grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice); 7407d8d0e25Snbeams } 7417d8d0e25Snbeams } 7427d8d0e25Snbeams 7437d8d0e25Snbeams //------------------------------------------------------------------------------ 7447d8d0e25Snbeams // Weight kernels by dim 7457d8d0e25Snbeams //------------------------------------------------------------------------------ 7469e31c45bSnbeams extern "C" __launch_bounds__(WEIGHT_BLKSIZE) __global__ void weight(const CeedInt nelem, 7477d8d0e25Snbeams const CeedScalar *__restrict__ qweight1d, 7487d8d0e25Snbeams CeedScalar *__restrict__ v) { 7497d8d0e25Snbeams if (BASIS_DIM == 1) { 7507d8d0e25Snbeams weight1d(nelem, qweight1d, v); 7517d8d0e25Snbeams } else if (BASIS_DIM == 2) { 7527d8d0e25Snbeams weight2d(nelem, qweight1d, v); 7537d8d0e25Snbeams } else if (BASIS_DIM == 3) { 7547d8d0e25Snbeams weight3d(nelem, qweight1d, v); 7557d8d0e25Snbeams } 7567d8d0e25Snbeams } 7577d8d0e25Snbeams 7587d8d0e25Snbeams ); 7597d8d0e25Snbeams // *INDENT-ON* 7607d8d0e25Snbeams 7617d8d0e25Snbeams //------------------------------------------------------------------------------ 7629e31c45bSnbeams // Compute a block size based on required minimum threads 7639e31c45bSnbeams //------------------------------------------------------------------------------ 7649e31c45bSnbeams static CeedInt ComputeBlockSizeFromRequirement(const CeedInt required) { 7659e31c45bSnbeams CeedInt maxSize = 1024; // Max total threads per block 7669e31c45bSnbeams CeedInt currentSize = 64; // Start with one group 7679e31c45bSnbeams 7689e31c45bSnbeams while(currentSize < maxSize) { 7699e31c45bSnbeams if (currentSize > required) 7709e31c45bSnbeams break; 7719e31c45bSnbeams else 7729e31c45bSnbeams currentSize = currentSize * 2; 7739e31c45bSnbeams } 7749e31c45bSnbeams return currentSize; 
7759e31c45bSnbeams } 7769e31c45bSnbeams 7779e31c45bSnbeams //------------------------------------------------------------------------------ 7789e31c45bSnbeams // Compute required thread block sizes for basis kernels given P, Q, dim, and 7799e31c45bSnbeams // ncomp 7809e31c45bSnbeams //------------------------------------------------------------------------------ 7819e31c45bSnbeams static int ComputeBasisThreadBlockSizes(const CeedInt dim, const CeedInt P1d, 7829e31c45bSnbeams const CeedInt Q1d, 7839e31c45bSnbeams const CeedInt ncomp, CeedInt *blksizes) { 7849e31c45bSnbeams 7859e31c45bSnbeams // Note that this will use the same block sizes for all dimensions when compiling, 7869e31c45bSnbeams // but as each basis object is defined for a particular dimension, we will never 7879e31c45bSnbeams // call any kernels except the ones for the dimension for which we have computed the 7889e31c45bSnbeams // block sizes. 7899e31c45bSnbeams const CeedInt thread1d = CeedIntMax(P1d, Q1d); 7909e31c45bSnbeams switch (dim) { 7919e31c45bSnbeams case 1: { 7929e31c45bSnbeams // Interp kernels: 7939e31c45bSnbeams blksizes[0] = 256; 7949e31c45bSnbeams 7959e31c45bSnbeams // Grad kernels: 7969e31c45bSnbeams blksizes[1] = 256; 7979e31c45bSnbeams 7989e31c45bSnbeams // Weight kernels: 7999e31c45bSnbeams blksizes[2] = 256; 8009e31c45bSnbeams 8019e31c45bSnbeams } break; 8029e31c45bSnbeams case 2: { 8039e31c45bSnbeams // Interp kernels: 8049e31c45bSnbeams CeedInt required = thread1d * thread1d * ncomp; 8059e31c45bSnbeams blksizes[0] = ComputeBlockSizeFromRequirement(required); 8069e31c45bSnbeams 8079e31c45bSnbeams // Grad kernels: currently use same required minimum threads 8089e31c45bSnbeams blksizes[1] = ComputeBlockSizeFromRequirement(required); 8099e31c45bSnbeams 8109e31c45bSnbeams // Weight kernels: 8119e31c45bSnbeams required = CeedIntMax(64, Q1d * Q1d); 8129e31c45bSnbeams blksizes[2] = ComputeBlockSizeFromRequirement(required); 8139e31c45bSnbeams 8149e31c45bSnbeams } break; 
8159e31c45bSnbeams case 3: { 8169e31c45bSnbeams // Interp kernels: 8179e31c45bSnbeams CeedInt required = thread1d * thread1d * ncomp; 8189e31c45bSnbeams blksizes[0] = ComputeBlockSizeFromRequirement(required); 8199e31c45bSnbeams 8209e31c45bSnbeams // Grad kernels: currently use same required minimum threads 8219e31c45bSnbeams blksizes[1] = ComputeBlockSizeFromRequirement(required); 8229e31c45bSnbeams 8239e31c45bSnbeams // Weight kernels: 8249e31c45bSnbeams required = Q1d * Q1d * Q1d; 8259e31c45bSnbeams blksizes[2] = ComputeBlockSizeFromRequirement(required); 8269e31c45bSnbeams } 8279e31c45bSnbeams } 8289e31c45bSnbeams 8299e31c45bSnbeams return 0; 8309e31c45bSnbeams } 8319e31c45bSnbeams 8329e31c45bSnbeams //------------------------------------------------------------------------------ 8337d8d0e25Snbeams // Apply basis 8347d8d0e25Snbeams //------------------------------------------------------------------------------ 8357d8d0e25Snbeams int CeedBasisApplyTensor_Hip_shared(CeedBasis basis, const CeedInt nelem, 8367d8d0e25Snbeams CeedTransposeMode tmode, 8377d8d0e25Snbeams CeedEvalMode emode, CeedVector u, 8387d8d0e25Snbeams CeedVector v) { 8397d8d0e25Snbeams int ierr; 8407d8d0e25Snbeams Ceed ceed; 8417d8d0e25Snbeams ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr); 8427d8d0e25Snbeams Ceed_Hip_shared *ceed_Hip; 8437d8d0e25Snbeams CeedGetData(ceed, &ceed_Hip); CeedChk(ierr); 8447d8d0e25Snbeams CeedBasis_Hip_shared *data; 8457d8d0e25Snbeams CeedBasisGetData(basis, &data); CeedChk(ierr); 8467d8d0e25Snbeams const CeedInt transpose = tmode == CEED_TRANSPOSE; 8477d8d0e25Snbeams CeedInt dim, ncomp; 8487d8d0e25Snbeams ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr); 8497d8d0e25Snbeams ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr); 8507d8d0e25Snbeams 8517d8d0e25Snbeams // Read vectors 8527d8d0e25Snbeams const CeedScalar *d_u; 8537d8d0e25Snbeams CeedScalar *d_v; 8547d8d0e25Snbeams if (emode != CEED_EVAL_WEIGHT) { 8557d8d0e25Snbeams ierr = 
CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChk(ierr); 8567d8d0e25Snbeams } 8577d8d0e25Snbeams ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChk(ierr); 8587d8d0e25Snbeams 8597d8d0e25Snbeams // Clear v for transpose mode 8607d8d0e25Snbeams if (tmode == CEED_TRANSPOSE) { 8617d8d0e25Snbeams CeedInt length; 8627d8d0e25Snbeams ierr = CeedVectorGetLength(v, &length); CeedChk(ierr); 8637d8d0e25Snbeams ierr = hipMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChk(ierr); 8647d8d0e25Snbeams } 8657d8d0e25Snbeams 8667d8d0e25Snbeams // Apply basis operation 8677d8d0e25Snbeams switch (emode) { 8687d8d0e25Snbeams case CEED_EVAL_INTERP: { 8697d8d0e25Snbeams CeedInt P1d, Q1d; 8709e31c45bSnbeams CeedInt blksize = data->blksizes[0]; 8717d8d0e25Snbeams ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr); 8727d8d0e25Snbeams ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr); 8737d8d0e25Snbeams CeedInt thread1d = CeedIntMax(Q1d, P1d); 8747d8d0e25Snbeams ierr = CeedHipInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B); 8757d8d0e25Snbeams CeedChk(ierr); 8767d8d0e25Snbeams void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B, 8777d8d0e25Snbeams &d_u, &d_v 8787d8d0e25Snbeams }; 8797d8d0e25Snbeams if (dim == 1) { 880e7ea6884Snbeams CeedInt elemsPerBlock = 64*thread1d > 256? 256/thread1d : 64; 8817d8d0e25Snbeams elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1; 8827d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 8837d8d0e25Snbeams ? 
1 : 0 ); 8847d8d0e25Snbeams CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar); 8857d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, 1, 8867d8d0e25Snbeams elemsPerBlock, sharedMem, 8877d8d0e25Snbeams interpargs); CeedChk(ierr); 8887d8d0e25Snbeams } else if (dim == 2) { 8899e31c45bSnbeams // Check if required threads is small enough to do multiple elems 8909e31c45bSnbeams const CeedInt elemsPerBlock = CeedIntMax(blksize/(thread1d*thread1d*ncomp), 1); 8917d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 8927d8d0e25Snbeams ? 1 : 0 ); 8937d8d0e25Snbeams CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar); 8947d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d, 8957d8d0e25Snbeams ncomp*elemsPerBlock, sharedMem, 8967d8d0e25Snbeams interpargs); CeedChk(ierr); 8977d8d0e25Snbeams } else if (dim == 3) { 8987d8d0e25Snbeams CeedInt elemsPerBlock = 1; 8997d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 9007d8d0e25Snbeams ? 
1 : 0 ); 9017d8d0e25Snbeams CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar); 9027d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d, 9037d8d0e25Snbeams ncomp*elemsPerBlock, sharedMem, 9047d8d0e25Snbeams interpargs); CeedChk(ierr); 9057d8d0e25Snbeams } 9067d8d0e25Snbeams } break; 9077d8d0e25Snbeams case CEED_EVAL_GRAD: { 9087d8d0e25Snbeams CeedInt P1d, Q1d; 9099e31c45bSnbeams CeedInt blksize = data->blksizes[1]; 9107d8d0e25Snbeams ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr); 9117d8d0e25Snbeams ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr); 9127d8d0e25Snbeams CeedInt thread1d = CeedIntMax(Q1d, P1d); 9137d8d0e25Snbeams ierr = CeedHipInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d, 9147d8d0e25Snbeams Q1d, &data->c_B, &data->c_G); 9157d8d0e25Snbeams CeedChk(ierr); 9167d8d0e25Snbeams void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B, 9177d8d0e25Snbeams &data->c_G, &d_u, &d_v 9187d8d0e25Snbeams }; 9197d8d0e25Snbeams if (dim == 1) { 920e7ea6884Snbeams CeedInt elemsPerBlock = 64*thread1d > 256? 256/thread1d : 64; 9217d8d0e25Snbeams elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1; 9227d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 9237d8d0e25Snbeams ? 1 : 0 ); 9247d8d0e25Snbeams CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar); 9257d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, 1, 9267d8d0e25Snbeams elemsPerBlock, sharedMem, gradargs); 9277d8d0e25Snbeams CeedChk(ierr); 9287d8d0e25Snbeams } else if (dim == 2) { 9299e31c45bSnbeams // Check if required threads is small enough to do multiple elems 9309e31c45bSnbeams const CeedInt elemsPerBlock = CeedIntMax(blksize/(thread1d*thread1d*ncomp), 1); 9317d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 9327d8d0e25Snbeams ? 
1 : 0 ); 9337d8d0e25Snbeams CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar); 9347d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d, 9357d8d0e25Snbeams ncomp*elemsPerBlock, sharedMem, 9367d8d0e25Snbeams gradargs); CeedChk(ierr); 9377d8d0e25Snbeams } else if (dim == 3) { 9387d8d0e25Snbeams CeedInt elemsPerBlock = 1; 9397d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 9407d8d0e25Snbeams ? 1 : 0 ); 9417d8d0e25Snbeams CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar); 9427d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d, 9437d8d0e25Snbeams ncomp*elemsPerBlock, sharedMem, 9447d8d0e25Snbeams gradargs); CeedChk(ierr); 9457d8d0e25Snbeams } 9467d8d0e25Snbeams } break; 9477d8d0e25Snbeams case CEED_EVAL_WEIGHT: { 9487d8d0e25Snbeams CeedInt Q1d; 9499e31c45bSnbeams CeedInt blksize = data->blksizes[2]; 9507d8d0e25Snbeams ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr); 9517d8d0e25Snbeams void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v}; 9527d8d0e25Snbeams if (dim == 1) { 9539e31c45bSnbeams const CeedInt optElems = blksize/Q1d; 9547d8d0e25Snbeams const CeedInt elemsPerBlock = optElems>0?optElems:1; 9557d8d0e25Snbeams const CeedInt gridsize = nelem/elemsPerBlock + ( ( 9567d8d0e25Snbeams nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 ); 9577d8d0e25Snbeams ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, 9587d8d0e25Snbeams elemsPerBlock, 1, weightargs); 9597d8d0e25Snbeams CeedChk(ierr); 9607d8d0e25Snbeams } else if (dim == 2) { 9619e31c45bSnbeams const CeedInt optElems = blksize/(Q1d*Q1d); 9627d8d0e25Snbeams const CeedInt elemsPerBlock = optElems>0?optElems:1; 9637d8d0e25Snbeams const CeedInt gridsize = nelem/elemsPerBlock + ( ( 9647d8d0e25Snbeams nelem/elemsPerBlock*elemsPerBlock<nelem)? 
1 : 0 ); 9657d8d0e25Snbeams ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d, 9667d8d0e25Snbeams elemsPerBlock, weightargs); 9677d8d0e25Snbeams CeedChk(ierr); 9687d8d0e25Snbeams } else if (dim == 3) { 9697d8d0e25Snbeams const CeedInt gridsize = nelem; 9707d8d0e25Snbeams ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d, Q1d, 9717d8d0e25Snbeams weightargs); 9727d8d0e25Snbeams CeedChk(ierr); 9737d8d0e25Snbeams } 9747d8d0e25Snbeams } break; 9757d8d0e25Snbeams // LCOV_EXCL_START 9767d8d0e25Snbeams // Evaluate the divergence to/from the quadrature points 9777d8d0e25Snbeams case CEED_EVAL_DIV: 9787d8d0e25Snbeams return CeedError(ceed, 1, "CEED_EVAL_DIV not supported"); 9797d8d0e25Snbeams // Evaluate the curl to/from the quadrature points 9807d8d0e25Snbeams case CEED_EVAL_CURL: 9817d8d0e25Snbeams return CeedError(ceed, 1, "CEED_EVAL_CURL not supported"); 9827d8d0e25Snbeams // Take no action, BasisApply should not have been called 9837d8d0e25Snbeams case CEED_EVAL_NONE: 9847d8d0e25Snbeams return CeedError(ceed, 1, 9857d8d0e25Snbeams "CEED_EVAL_NONE does not make sense in this context"); 9867d8d0e25Snbeams // LCOV_EXCL_STOP 9877d8d0e25Snbeams } 9887d8d0e25Snbeams 9897d8d0e25Snbeams // Restore vectors 9907d8d0e25Snbeams if (emode != CEED_EVAL_WEIGHT) { 9917d8d0e25Snbeams ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChk(ierr); 9927d8d0e25Snbeams } 9937d8d0e25Snbeams ierr = CeedVectorRestoreArray(v, &d_v); CeedChk(ierr); 9947d8d0e25Snbeams return 0; 9957d8d0e25Snbeams } 9967d8d0e25Snbeams 9977d8d0e25Snbeams //------------------------------------------------------------------------------ 9987d8d0e25Snbeams // Destroy basis 9997d8d0e25Snbeams //------------------------------------------------------------------------------ 10007d8d0e25Snbeams static int CeedBasisDestroy_Hip_shared(CeedBasis basis) { 10017d8d0e25Snbeams int ierr; 10027d8d0e25Snbeams Ceed ceed; 10037d8d0e25Snbeams ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr); 
10047d8d0e25Snbeams 10057d8d0e25Snbeams CeedBasis_Hip_shared *data; 10067d8d0e25Snbeams ierr = CeedBasisGetData(basis, &data); CeedChk(ierr); 10077d8d0e25Snbeams 10087d8d0e25Snbeams CeedChk_Hip(ceed, hipModuleUnload(data->module)); 10097d8d0e25Snbeams 10107d8d0e25Snbeams ierr = hipFree(data->d_qweight1d); CeedChk_Hip(ceed, ierr); 10117d8d0e25Snbeams ierr = hipFree(data->d_interp1d); CeedChk_Hip(ceed, ierr); 10127d8d0e25Snbeams ierr = hipFree(data->d_grad1d); CeedChk_Hip(ceed, ierr); 10137d8d0e25Snbeams ierr = hipFree(data->d_collograd1d); CeedChk_Hip(ceed, ierr); 10147d8d0e25Snbeams 10157d8d0e25Snbeams ierr = CeedFree(&data); CeedChk(ierr); 10167d8d0e25Snbeams 10177d8d0e25Snbeams return 0; 10187d8d0e25Snbeams } 10197d8d0e25Snbeams 10207d8d0e25Snbeams //------------------------------------------------------------------------------ 10217d8d0e25Snbeams // Create tensor basis 10227d8d0e25Snbeams //------------------------------------------------------------------------------ 10237d8d0e25Snbeams int CeedBasisCreateTensorH1_Hip_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d, 10247d8d0e25Snbeams const CeedScalar *interp1d, 10257d8d0e25Snbeams const CeedScalar *grad1d, 10267d8d0e25Snbeams const CeedScalar *qref1d, 10277d8d0e25Snbeams const CeedScalar *qweight1d, 10287d8d0e25Snbeams CeedBasis basis) { 10297d8d0e25Snbeams int ierr; 10307d8d0e25Snbeams Ceed ceed; 10317d8d0e25Snbeams ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr); 10327d8d0e25Snbeams CeedBasis_Hip_shared *data; 10337d8d0e25Snbeams ierr = CeedCalloc(1, &data); CeedChk(ierr); 10347d8d0e25Snbeams 10357d8d0e25Snbeams // Copy basis data to GPU 10367d8d0e25Snbeams const CeedInt qBytes = Q1d * sizeof(CeedScalar); 10377d8d0e25Snbeams ierr = hipMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Hip(ceed, ierr); 10387d8d0e25Snbeams ierr = hipMemcpy(data->d_qweight1d, qweight1d, qBytes, 10397d8d0e25Snbeams hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr); 10407d8d0e25Snbeams 10417d8d0e25Snbeams const CeedInt iBytes 
= qBytes * P1d; 10427d8d0e25Snbeams ierr = hipMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Hip(ceed, ierr); 10437d8d0e25Snbeams ierr = hipMemcpy(data->d_interp1d, interp1d, iBytes, 10447d8d0e25Snbeams hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr); 10457d8d0e25Snbeams 10467d8d0e25Snbeams ierr = hipMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Hip(ceed, ierr); 10477d8d0e25Snbeams ierr = hipMemcpy(data->d_grad1d, grad1d, iBytes, 10487d8d0e25Snbeams hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr); 10497d8d0e25Snbeams 10507d8d0e25Snbeams // Compute collocated gradient and copy to GPU 10517d8d0e25Snbeams data->d_collograd1d = NULL; 10527d8d0e25Snbeams if (dim == 3 && Q1d >= P1d) { 10537d8d0e25Snbeams CeedScalar *collograd1d; 10547d8d0e25Snbeams ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChk(ierr); 10557d8d0e25Snbeams ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChk(ierr); 10567d8d0e25Snbeams ierr = hipMalloc((void **)&data->d_collograd1d, qBytes * Q1d); 10577d8d0e25Snbeams CeedChk_Hip(ceed, ierr); 10587d8d0e25Snbeams ierr = hipMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d, 10597d8d0e25Snbeams hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr); 10607d8d0e25Snbeams ierr = CeedFree(&collograd1d); CeedChk(ierr); 10617d8d0e25Snbeams } 10627d8d0e25Snbeams 10639e31c45bSnbeams // Set number of threads per block for basis kernels 10647d8d0e25Snbeams CeedInt ncomp; 10657d8d0e25Snbeams ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr); 10669e31c45bSnbeams ierr = ComputeBasisThreadBlockSizes(dim, P1d, Q1d, ncomp, data->blksizes); 10679e31c45bSnbeams CeedChk(ierr); 10689e31c45bSnbeams 10699e31c45bSnbeams // Compile basis kernels 10709e31c45bSnbeams ierr = CeedCompileHip(ceed, kernelsShared, &data->module, 11, 10717d8d0e25Snbeams "Q1D", Q1d, 10727d8d0e25Snbeams "P1D", P1d, 10737d8d0e25Snbeams "T1D", CeedIntMax(Q1d, P1d), 10747d8d0e25Snbeams "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ? 
10757d8d0e25Snbeams Q1d : P1d, dim), 10767d8d0e25Snbeams "BASIS_DIM", dim, 10777d8d0e25Snbeams "BASIS_NCOMP", ncomp, 10787d8d0e25Snbeams "BASIS_ELEMSIZE", CeedIntPow(P1d, dim), 10799e31c45bSnbeams "BASIS_NQPT", CeedIntPow(Q1d, dim), 10809e31c45bSnbeams "INTERP_BLKSIZE", data->blksizes[0], 10819e31c45bSnbeams "GRAD_BLKSIZE", data->blksizes[1], 10829e31c45bSnbeams "WEIGHT_BLKSIZE", data->blksizes[2] 10837d8d0e25Snbeams ); CeedChk(ierr); 10847d8d0e25Snbeams ierr = CeedGetKernelHip(ceed, data->module, "interp", &data->interp); 10857d8d0e25Snbeams CeedChk(ierr); 10867d8d0e25Snbeams ierr = CeedGetKernelHip(ceed, data->module, "grad", &data->grad); 10877d8d0e25Snbeams CeedChk(ierr); 10887d8d0e25Snbeams ierr = CeedGetKernelHip(ceed, data->module, "weight", &data->weight); 10897d8d0e25Snbeams CeedChk(ierr); 10907d8d0e25Snbeams 10917d8d0e25Snbeams ierr = CeedBasisSetData(basis, data); CeedChk(ierr); 10927d8d0e25Snbeams 10937d8d0e25Snbeams // Register backend functions 10947d8d0e25Snbeams ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply", 10957d8d0e25Snbeams CeedBasisApplyTensor_Hip_shared); 10967d8d0e25Snbeams CeedChk(ierr); 10977d8d0e25Snbeams ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", 10987d8d0e25Snbeams CeedBasisDestroy_Hip_shared); CeedChk(ierr); 10997d8d0e25Snbeams return 0; 11007d8d0e25Snbeams } 11017d8d0e25Snbeams //------------------------------------------------------------------------------ 1102