// Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
// Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
// All Rights reserved. See files LICENSE and NOTICE for details.
//
// This file is part of CEED, a collection of benchmarks, miniapps, software
// libraries and APIs for efficient high-order finite element and spectral
// element discretizations for exascale applications. For more information and
// source code availability see http://github.com/ceed.
//
// The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
// a collaborative effort of two U.S. Department of Energy organizations (Office
// of Science and the National Nuclear Security Administration) responsible for
// the planning and preparation of a capable exascale ecosystem, including
// software, applications, hardware, advanced system engineering and early
// testbed platforms, in support of the nation's exascale computing imperative.
167d8d0e25Snbeams 173d576824SJeremy L Thompson #include <ceed.h> 183d576824SJeremy L Thompson #include <ceed-backend.h> 193d576824SJeremy L Thompson #include <hip/hip_runtime.h> 203d576824SJeremy L Thompson #include <stddef.h> 217d8d0e25Snbeams #include "ceed-hip-shared.h" 223d576824SJeremy L Thompson #include "../hip/ceed-hip.h" 237d8d0e25Snbeams #include "../hip/ceed-hip-compile.h" 247d8d0e25Snbeams 257d8d0e25Snbeams //------------------------------------------------------------------------------ 267d8d0e25Snbeams // Shared mem kernels 277d8d0e25Snbeams //------------------------------------------------------------------------------ 287d8d0e25Snbeams // *INDENT-OFF* 297d8d0e25Snbeams static const char *kernelsShared = QUOTE( 307d8d0e25Snbeams 317d8d0e25Snbeams //------------------------------------------------------------------------------ 327d8d0e25Snbeams // Sum input into output 337d8d0e25Snbeams //------------------------------------------------------------------------------ 347d8d0e25Snbeams inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) { 357d8d0e25Snbeams for (int i = 0; i < P1D; i++) 367d8d0e25Snbeams r_V[i] += r_U[i]; 377d8d0e25Snbeams } 387d8d0e25Snbeams 397d8d0e25Snbeams //------------------------------------------------------------------------------ 409dd88646Snbeams // Load matrices for basis actions 419dd88646Snbeams //------------------------------------------------------------------------------ 429dd88646Snbeams inline __device__ void loadMatrix(const CeedScalar* d_B, CeedScalar* B) { 439dd88646Snbeams CeedInt tid = threadIdx.x + threadIdx.y*blockDim.x + threadIdx.z*blockDim.y*blockDim.x; 449dd88646Snbeams for (CeedInt i = tid; i < P1D*Q1D; i += blockDim.x*blockDim.y*blockDim.z) 459dd88646Snbeams B[i] = d_B[i]; 469dd88646Snbeams } 479dd88646Snbeams 489dd88646Snbeams //------------------------------------------------------------------------------ 497d8d0e25Snbeams // 1D 507d8d0e25Snbeams 
//------------------------------------------------------------------------------ 517d8d0e25Snbeams 527d8d0e25Snbeams //------------------------------------------------------------------------------ 537d8d0e25Snbeams // Read DoFs 547d8d0e25Snbeams //------------------------------------------------------------------------------ 557d8d0e25Snbeams inline __device__ void readDofs1d(const int elem, const int tidx, 567d8d0e25Snbeams const int tidy, const int tidz,const int comp, 577d8d0e25Snbeams const int nelem, const CeedScalar *d_U, 587d8d0e25Snbeams CeedScalar *slice) { 597d8d0e25Snbeams for (int i = 0; i < P1D; i++) 607d8d0e25Snbeams slice[i + tidz*T1D] = d_U[i + elem*P1D + comp*P1D*nelem]; 617d8d0e25Snbeams for (int i = P1D; i < Q1D; i++) 627d8d0e25Snbeams slice[i + tidz*T1D] = 0.0; 637d8d0e25Snbeams } 647d8d0e25Snbeams 657d8d0e25Snbeams //------------------------------------------------------------------------------ 667d8d0e25Snbeams // Write DoFs 677d8d0e25Snbeams //------------------------------------------------------------------------------ 687d8d0e25Snbeams inline __device__ void writeDofs1d(const int elem, const int tidx, 697d8d0e25Snbeams const int tidy, const int comp, 707d8d0e25Snbeams const int nelem, const CeedScalar &r_V, 717d8d0e25Snbeams CeedScalar *d_V) { 727d8d0e25Snbeams if (tidx<P1D) 737d8d0e25Snbeams d_V[tidx + elem*P1D + comp*P1D*nelem] = r_V; 747d8d0e25Snbeams } 757d8d0e25Snbeams 767d8d0e25Snbeams //------------------------------------------------------------------------------ 777d8d0e25Snbeams // Read quadrature point data 787d8d0e25Snbeams //------------------------------------------------------------------------------ 797d8d0e25Snbeams inline __device__ void readQuads1d(const int elem, const int tidx, 807d8d0e25Snbeams const int tidy, const int tidz, const int comp, 817d8d0e25Snbeams const int dim, const int nelem, 827d8d0e25Snbeams const CeedScalar *d_U, CeedScalar *slice) { 837d8d0e25Snbeams for (int i = 0; i < Q1D; i++) 847d8d0e25Snbeams 
slice[i + tidz*T1D] = d_U[i + elem*Q1D + comp*Q1D*nelem + 857d8d0e25Snbeams dim*BASIS_NCOMP*nelem*Q1D]; 867d8d0e25Snbeams for (int i = Q1D; i < P1D; i++) 877d8d0e25Snbeams slice[i + tidz*T1D] = 0.0; 887d8d0e25Snbeams } 897d8d0e25Snbeams 907d8d0e25Snbeams //------------------------------------------------------------------------------ 917d8d0e25Snbeams // Write quadrature point data 927d8d0e25Snbeams //------------------------------------------------------------------------------ 937d8d0e25Snbeams inline __device__ void writeQuads1d(const int elem, const int tidx, 947d8d0e25Snbeams const int tidy, const int comp, 957d8d0e25Snbeams const int dim, const int nelem, 967d8d0e25Snbeams const CeedScalar &r_V, CeedScalar *d_V) { 977d8d0e25Snbeams if (tidx<Q1D) 987d8d0e25Snbeams d_V[tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D] = r_V; 997d8d0e25Snbeams } 1007d8d0e25Snbeams 1017d8d0e25Snbeams //------------------------------------------------------------------------------ 1027d8d0e25Snbeams // 1D tensor contraction 1037d8d0e25Snbeams //------------------------------------------------------------------------------ 1047d8d0e25Snbeams inline __device__ void ContractX1d(CeedScalar *slice, const int tidx, 1057d8d0e25Snbeams const int tidy, const int tidz, 1067d8d0e25Snbeams const CeedScalar &U, const CeedScalar *B, 1077d8d0e25Snbeams CeedScalar &V) { 1087d8d0e25Snbeams V = 0.0; 1097d8d0e25Snbeams for (int i = 0; i < P1D; ++i) 1107d8d0e25Snbeams V += B[i + tidx*P1D] * slice[i + tidz*T1D]; // Contract x direction 1117d8d0e25Snbeams } 1127d8d0e25Snbeams 1137d8d0e25Snbeams //------------------------------------------------------------------------------ 1147d8d0e25Snbeams // 1D transpose tensor contraction 1157d8d0e25Snbeams //------------------------------------------------------------------------------ 1167d8d0e25Snbeams inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx, 1177d8d0e25Snbeams const int tidy, const int tidz, 1187d8d0e25Snbeams 
const CeedScalar &U, const CeedScalar *B, CeedScalar &V) { 1197d8d0e25Snbeams V = 0.0; 1207d8d0e25Snbeams for (int i = 0; i < Q1D; ++i) 1217d8d0e25Snbeams V += B[tidx + i*P1D] * slice[i + tidz*T1D]; // Contract x direction 1227d8d0e25Snbeams } 1237d8d0e25Snbeams 1247d8d0e25Snbeams //------------------------------------------------------------------------------ 1257d8d0e25Snbeams // 1D interpolate to quadrature points 1267d8d0e25Snbeams //------------------------------------------------------------------------------ 1277d8d0e25Snbeams inline __device__ void interp1d(const CeedInt nelem, const int transpose, 128e9132427Snbeams const CeedScalar *s_B, 1297d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 1307d8d0e25Snbeams CeedScalar *__restrict__ d_V, 1317d8d0e25Snbeams CeedScalar *slice) { 1327d8d0e25Snbeams CeedScalar r_V; 1337d8d0e25Snbeams CeedScalar r_t; 1347d8d0e25Snbeams 1357d8d0e25Snbeams const int tidx = threadIdx.x; 1367d8d0e25Snbeams const int tidy = threadIdx.y; 1377d8d0e25Snbeams const int tidz = threadIdx.z; 1387d8d0e25Snbeams 1397d8d0e25Snbeams 1407d8d0e25Snbeams for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem; 1417d8d0e25Snbeams elem += gridDim.x*blockDim.z) { 1427d8d0e25Snbeams for (int comp = 0; comp < BASIS_NCOMP; comp++) { 1437d8d0e25Snbeams if (!transpose) { 1447d8d0e25Snbeams readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice); 145e9132427Snbeams ContractX1d(slice, tidx, tidy, tidz, r_t, s_B, r_V); 1467d8d0e25Snbeams writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V); 1477d8d0e25Snbeams } else { 1487d8d0e25Snbeams readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice); 149e9132427Snbeams ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, s_B, r_V); 1507d8d0e25Snbeams writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V); 1517d8d0e25Snbeams } 1527d8d0e25Snbeams } 1537d8d0e25Snbeams } 1547d8d0e25Snbeams } 1557d8d0e25Snbeams 1567d8d0e25Snbeams 
//------------------------------------------------------------------------------ 1577d8d0e25Snbeams // 1D derivatives at quadrature points 1587d8d0e25Snbeams //------------------------------------------------------------------------------ 1597d8d0e25Snbeams inline __device__ void grad1d(const CeedInt nelem, const int transpose, 160e9132427Snbeams const CeedScalar *s_B, const CeedScalar *s_G, 1617d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 1627d8d0e25Snbeams CeedScalar *__restrict__ d_V, 1637d8d0e25Snbeams CeedScalar *slice) { 1647d8d0e25Snbeams CeedScalar r_U; 1657d8d0e25Snbeams CeedScalar r_V; 1667d8d0e25Snbeams 1677d8d0e25Snbeams const int tidx = threadIdx.x; 1687d8d0e25Snbeams const int tidy = threadIdx.y; 1697d8d0e25Snbeams const int tidz = threadIdx.z; 1707d8d0e25Snbeams int dim; 1717d8d0e25Snbeams 1727d8d0e25Snbeams for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem; 1737d8d0e25Snbeams elem += gridDim.x*blockDim.z) { 1747d8d0e25Snbeams for(int comp = 0; comp < BASIS_NCOMP; comp++) { 1757d8d0e25Snbeams if (!transpose) { 1767d8d0e25Snbeams readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice); 177e9132427Snbeams ContractX1d(slice, tidx, tidy, tidz, r_U, s_G, r_V); 1787d8d0e25Snbeams dim = 0; 1797d8d0e25Snbeams writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V); 1807d8d0e25Snbeams } else { 1817d8d0e25Snbeams dim = 0; 1827d8d0e25Snbeams readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice); 183e9132427Snbeams ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, s_G, r_V); 1847d8d0e25Snbeams writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V); 1857d8d0e25Snbeams } 1867d8d0e25Snbeams } 1877d8d0e25Snbeams } 1887d8d0e25Snbeams } 1897d8d0e25Snbeams 1907d8d0e25Snbeams //------------------------------------------------------------------------------ 1917d8d0e25Snbeams // 1D Quadrature weights 1927d8d0e25Snbeams //------------------------------------------------------------------------------ 1937d8d0e25Snbeams __device__ 
void weight1d(const CeedInt nelem, const CeedScalar *qweight1d, 1947d8d0e25Snbeams CeedScalar *w) { 1957d8d0e25Snbeams const int tid = threadIdx.x; 1967d8d0e25Snbeams const CeedScalar weight = qweight1d[tid]; 1977d8d0e25Snbeams for (CeedInt elem = blockIdx.x*blockDim.y + threadIdx.y; elem < nelem; 1987d8d0e25Snbeams elem += gridDim.x*blockDim.y) { 1997d8d0e25Snbeams const int ind = elem*Q1D + tid; 2007d8d0e25Snbeams w[ind] = weight; 2017d8d0e25Snbeams } 2027d8d0e25Snbeams } 2037d8d0e25Snbeams 2047d8d0e25Snbeams //------------------------------------------------------------------------------ 2057d8d0e25Snbeams // 2D 2067d8d0e25Snbeams //------------------------------------------------------------------------------ 2077d8d0e25Snbeams 2087d8d0e25Snbeams //------------------------------------------------------------------------------ 2097d8d0e25Snbeams // Read DoFs 2107d8d0e25Snbeams //------------------------------------------------------------------------------ 2117d8d0e25Snbeams inline __device__ void readDofs2d(const int elem, const int tidx, 2127d8d0e25Snbeams const int tidy, const int comp, 2137d8d0e25Snbeams const int nelem, const CeedScalar *d_U, 2147d8d0e25Snbeams CeedScalar &U) { 2157d8d0e25Snbeams U = (tidx<P1D && tidy<P1D) ? 
2167d8d0e25Snbeams d_U[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] : 0.0; 2177d8d0e25Snbeams } 2187d8d0e25Snbeams 2197d8d0e25Snbeams //------------------------------------------------------------------------------ 2207d8d0e25Snbeams // Write DoFs 2217d8d0e25Snbeams //------------------------------------------------------------------------------ 2227d8d0e25Snbeams inline __device__ void writeDofs2d(const int elem, const int tidx, 2237d8d0e25Snbeams const int tidy, const int comp, 2247d8d0e25Snbeams const int nelem, const CeedScalar &r_V, 2257d8d0e25Snbeams CeedScalar *d_V) { 2267d8d0e25Snbeams if (tidx<P1D && tidy<P1D) 2277d8d0e25Snbeams d_V[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] = r_V; 2287d8d0e25Snbeams } 2297d8d0e25Snbeams 2307d8d0e25Snbeams //------------------------------------------------------------------------------ 2317d8d0e25Snbeams // Read quadrature point data 2327d8d0e25Snbeams //------------------------------------------------------------------------------ 2337d8d0e25Snbeams inline __device__ void readQuads2d(const int elem, const int tidx, 2347d8d0e25Snbeams const int tidy, const int comp, 2357d8d0e25Snbeams const int dim, const int nelem, 2367d8d0e25Snbeams const CeedScalar *d_U, CeedScalar &U ) { 2377d8d0e25Snbeams U = (tidx<Q1D && tidy<Q1D) ? 
2387d8d0e25Snbeams d_U[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem + 2397d8d0e25Snbeams dim*BASIS_NCOMP*nelem*Q1D*Q1D] : 0.0; 2407d8d0e25Snbeams } 2417d8d0e25Snbeams 2427d8d0e25Snbeams //------------------------------------------------------------------------------ 2437d8d0e25Snbeams // Write quadrature point data 2447d8d0e25Snbeams //------------------------------------------------------------------------------ 2457d8d0e25Snbeams inline __device__ void writeQuads2d(const int elem, const int tidx, 2467d8d0e25Snbeams const int tidy, const int comp, 2477d8d0e25Snbeams const int dim, const int nelem, 2487d8d0e25Snbeams const CeedScalar &r_V, CeedScalar *d_V) { 2497d8d0e25Snbeams if (tidx<Q1D && tidy<Q1D) 2507d8d0e25Snbeams d_V[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem + 2517d8d0e25Snbeams dim*BASIS_NCOMP*nelem*Q1D*Q1D] = r_V; 2527d8d0e25Snbeams } 2537d8d0e25Snbeams 2547d8d0e25Snbeams //------------------------------------------------------------------------------ 2557d8d0e25Snbeams // 2D tensor contraction x 2567d8d0e25Snbeams //------------------------------------------------------------------------------ 2577d8d0e25Snbeams inline __device__ void ContractX2d(CeedScalar *slice, const int tidx, 2587d8d0e25Snbeams const int tidy, const int tidz, 2597d8d0e25Snbeams const CeedScalar &U, const CeedScalar *B, 2607d8d0e25Snbeams CeedScalar &V) { 2617d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U; 2627d8d0e25Snbeams __syncthreads(); 2637d8d0e25Snbeams V = 0.0; 2647d8d0e25Snbeams if (tidx < Q1D) 2657d8d0e25Snbeams for (int i = 0; i < P1D; ++i) 2667d8d0e25Snbeams V += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction 2677d8d0e25Snbeams __syncthreads(); 2687d8d0e25Snbeams } 2697d8d0e25Snbeams 2707d8d0e25Snbeams //------------------------------------------------------------------------------ 2717d8d0e25Snbeams // 2D tensor contraction y 2727d8d0e25Snbeams 
//------------------------------------------------------------------------------ 2737d8d0e25Snbeams inline __device__ void ContractY2d(CeedScalar *slice, const int tidx, 2747d8d0e25Snbeams const int tidy, const int tidz, 2757d8d0e25Snbeams const CeedScalar &U, const CeedScalar *B, 2767d8d0e25Snbeams CeedScalar &V) { 2777d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U; 2787d8d0e25Snbeams __syncthreads(); 2797d8d0e25Snbeams V = 0.0; 2807d8d0e25Snbeams if (tidy < Q1D) 2817d8d0e25Snbeams for (int i = 0; i < P1D; ++i) 2827d8d0e25Snbeams V += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction 2837d8d0e25Snbeams __syncthreads(); 2847d8d0e25Snbeams } 2857d8d0e25Snbeams 2867d8d0e25Snbeams //------------------------------------------------------------------------------ 2877d8d0e25Snbeams // 2D transpose tensor contraction y 2887d8d0e25Snbeams //------------------------------------------------------------------------------ 2897d8d0e25Snbeams inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx, 2907d8d0e25Snbeams const int tidy, const int tidz, 2917d8d0e25Snbeams const CeedScalar &U, const CeedScalar *B, CeedScalar &V) { 2927d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U; 2937d8d0e25Snbeams __syncthreads(); 2947d8d0e25Snbeams V = 0.0; 2957d8d0e25Snbeams if (tidy < P1D) 2967d8d0e25Snbeams for (int i = 0; i < Q1D; ++i) 2977d8d0e25Snbeams V += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction 2987d8d0e25Snbeams __syncthreads(); 2997d8d0e25Snbeams } 3007d8d0e25Snbeams 3017d8d0e25Snbeams //------------------------------------------------------------------------------ 3027d8d0e25Snbeams // 2D transpose tensor contraction x 3037d8d0e25Snbeams //------------------------------------------------------------------------------ 3047d8d0e25Snbeams inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx, 3057d8d0e25Snbeams const int tidy, const int tidz, 
3067d8d0e25Snbeams const CeedScalar &U, const CeedScalar *B, CeedScalar &V) { 3077d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U; 3087d8d0e25Snbeams __syncthreads(); 3097d8d0e25Snbeams V = 0.0; 3107d8d0e25Snbeams if (tidx < P1D) 3117d8d0e25Snbeams for (int i = 0; i < Q1D; ++i) 3127d8d0e25Snbeams V += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction 3137d8d0e25Snbeams __syncthreads(); 3147d8d0e25Snbeams } 3157d8d0e25Snbeams 3167d8d0e25Snbeams //------------------------------------------------------------------------------ 3177d8d0e25Snbeams // 2D interpolate to quadrature points 3187d8d0e25Snbeams //------------------------------------------------------------------------------ 3197d8d0e25Snbeams inline __device__ void interp2d(const CeedInt nelem, const int transpose, 320e9132427Snbeams const CeedScalar *s_B, 3217d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 3227d8d0e25Snbeams CeedScalar *__restrict__ d_V, 3237d8d0e25Snbeams CeedScalar *slice) { 3247d8d0e25Snbeams CeedScalar r_V; 3257d8d0e25Snbeams CeedScalar r_t; 3267d8d0e25Snbeams 3277d8d0e25Snbeams const int tidx = threadIdx.x; 3287d8d0e25Snbeams const int tidy = threadIdx.y; 3297d8d0e25Snbeams const int tidz = threadIdx.z; 3307d8d0e25Snbeams const int blockElem = tidz/BASIS_NCOMP; 3317d8d0e25Snbeams const int elemsPerBlock = blockDim.z/BASIS_NCOMP; 3327d8d0e25Snbeams const int comp = tidz%BASIS_NCOMP; 3337d8d0e25Snbeams 3347d8d0e25Snbeams for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem; 3357d8d0e25Snbeams elem += gridDim.x*elemsPerBlock) { 3367d8d0e25Snbeams const int comp = tidz%BASIS_NCOMP; 3377d8d0e25Snbeams r_V = 0.0; 3387d8d0e25Snbeams r_t = 0.0; 3397d8d0e25Snbeams if (!transpose) { 3407d8d0e25Snbeams readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V); 341e9132427Snbeams ContractX2d(slice, tidx, tidy, tidz, r_V, s_B, r_t); 342e9132427Snbeams ContractY2d(slice, tidx, tidy, tidz, r_t, s_B, r_V); 3437d8d0e25Snbeams writeQuads2d(elem, tidx, 
tidy, comp, 0, nelem, r_V, d_V); 3447d8d0e25Snbeams } else { 3457d8d0e25Snbeams readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V); 346e9132427Snbeams ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, s_B, r_t); 347e9132427Snbeams ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, s_B, r_V); 3487d8d0e25Snbeams writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V); 3497d8d0e25Snbeams } 3507d8d0e25Snbeams } 3517d8d0e25Snbeams } 3527d8d0e25Snbeams 3537d8d0e25Snbeams //------------------------------------------------------------------------------ 3547d8d0e25Snbeams // 2D derivatives at quadrature points 3557d8d0e25Snbeams //------------------------------------------------------------------------------ 3567d8d0e25Snbeams inline __device__ void grad2d(const CeedInt nelem, const int transpose, 357e9132427Snbeams const CeedScalar *s_B, const CeedScalar *s_G, 3587d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 3597d8d0e25Snbeams CeedScalar *__restrict__ d_V, CeedScalar *slice) { 3607d8d0e25Snbeams CeedScalar r_U; 3617d8d0e25Snbeams CeedScalar r_V; 3627d8d0e25Snbeams CeedScalar r_t; 3637d8d0e25Snbeams 3647d8d0e25Snbeams const int tidx = threadIdx.x; 3657d8d0e25Snbeams const int tidy = threadIdx.y; 3667d8d0e25Snbeams const int tidz = threadIdx.z; 3677d8d0e25Snbeams const int blockElem = tidz/BASIS_NCOMP; 3687d8d0e25Snbeams const int elemsPerBlock = blockDim.z/BASIS_NCOMP; 3697d8d0e25Snbeams const int comp = tidz%BASIS_NCOMP; 3707d8d0e25Snbeams int dim; 3717d8d0e25Snbeams 3727d8d0e25Snbeams for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem; 3737d8d0e25Snbeams elem += gridDim.x*elemsPerBlock) { 3747d8d0e25Snbeams if (!transpose) { 3757d8d0e25Snbeams readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U); 376e9132427Snbeams ContractX2d(slice, tidx, tidy, tidz, r_U, s_G, r_t); 377e9132427Snbeams ContractY2d(slice, tidx, tidy, tidz, r_t, s_B, r_V); 3787d8d0e25Snbeams dim = 0; 3797d8d0e25Snbeams writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V); 
380e9132427Snbeams ContractX2d(slice, tidx, tidy, tidz, r_U, s_B, r_t); 381e9132427Snbeams ContractY2d(slice, tidx, tidy, tidz, r_t, s_G, r_V); 3827d8d0e25Snbeams dim = 1; 3837d8d0e25Snbeams writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V); 3847d8d0e25Snbeams } else { 3857d8d0e25Snbeams dim = 0; 3867d8d0e25Snbeams readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U); 387e9132427Snbeams ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, s_B, r_t); 388e9132427Snbeams ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, s_G, r_V); 3897d8d0e25Snbeams dim = 1; 3907d8d0e25Snbeams readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U); 391e9132427Snbeams ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, s_G, r_t); 392e9132427Snbeams ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, s_B, r_U); 3937d8d0e25Snbeams r_V += r_U; 3947d8d0e25Snbeams writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V); 3957d8d0e25Snbeams } 3967d8d0e25Snbeams } 3977d8d0e25Snbeams } 3987d8d0e25Snbeams 3997d8d0e25Snbeams //------------------------------------------------------------------------------ 4007d8d0e25Snbeams // 2D quadrature weights 4017d8d0e25Snbeams //------------------------------------------------------------------------------ 4027d8d0e25Snbeams __device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d, 4037d8d0e25Snbeams CeedScalar *w) { 4047d8d0e25Snbeams const int i = threadIdx.x; 4057d8d0e25Snbeams const int j = threadIdx.y; 4067d8d0e25Snbeams const CeedScalar weight = qweight1d[i]*qweight1d[j]; 4077d8d0e25Snbeams for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem; 4087d8d0e25Snbeams elem += gridDim.x*blockDim.z) { 4097d8d0e25Snbeams const int ind = elem*Q1D*Q1D + i + j*Q1D; 4107d8d0e25Snbeams w[ind] = weight; 4117d8d0e25Snbeams } 4127d8d0e25Snbeams } 4137d8d0e25Snbeams 4147d8d0e25Snbeams //------------------------------------------------------------------------------ 4157d8d0e25Snbeams // 3D 4167d8d0e25Snbeams 
//------------------------------------------------------------------------------ 4177d8d0e25Snbeams 4187d8d0e25Snbeams //------------------------------------------------------------------------------ 4197d8d0e25Snbeams // Read DoFs 4207d8d0e25Snbeams //------------------------------------------------------------------------------ 4217d8d0e25Snbeams inline __device__ void readDofs3d(const int elem, const int tidx, 4227d8d0e25Snbeams const int tidy, const int comp, 4237d8d0e25Snbeams const int nelem, const CeedScalar *d_U, 4247d8d0e25Snbeams CeedScalar *r_U) { 4257d8d0e25Snbeams for (int i = 0; i < P1D; i++) 4267d8d0e25Snbeams r_U[i] = (tidx < P1D && tidy < P1D) ? 4277d8d0e25Snbeams d_U[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D + 4287d8d0e25Snbeams comp*P1D*P1D*P1D*nelem] : 0.0; 4297d8d0e25Snbeams for (int i = P1D; i < Q1D; i++) 4307d8d0e25Snbeams r_U[i] = 0.0; 4317d8d0e25Snbeams } 4327d8d0e25Snbeams 4337d8d0e25Snbeams //------------------------------------------------------------------------------ 4347d8d0e25Snbeams // Write DoFs 4357d8d0e25Snbeams //------------------------------------------------------------------------------ 4367d8d0e25Snbeams inline __device__ void writeDofs3d(const int elem, const int tidx, 4377d8d0e25Snbeams const int tidy, const int comp, 4387d8d0e25Snbeams const int nelem, const CeedScalar *r_V, 4397d8d0e25Snbeams CeedScalar *d_V) { 4407d8d0e25Snbeams if (tidx < P1D && tidy < P1D) { 4417d8d0e25Snbeams for (int i = 0; i < P1D; i++) 4427d8d0e25Snbeams d_V[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D + 4437d8d0e25Snbeams comp*P1D*P1D*P1D*nelem] = r_V[i]; 4447d8d0e25Snbeams } 4457d8d0e25Snbeams } 4467d8d0e25Snbeams 4477d8d0e25Snbeams //------------------------------------------------------------------------------ 4487d8d0e25Snbeams // Read quadrature point data 4497d8d0e25Snbeams //------------------------------------------------------------------------------ 4507d8d0e25Snbeams inline __device__ void readQuads3d(const int elem, const int 
tidx, 4517d8d0e25Snbeams const int tidy, const int comp, 4527d8d0e25Snbeams const int dim, const int nelem, 4537d8d0e25Snbeams const CeedScalar *d_U, CeedScalar *r_U) { 4547d8d0e25Snbeams for (int i = 0; i < Q1D; i++) 4557d8d0e25Snbeams r_U[i] = (tidx < Q1D && tidy < Q1D) ? 4567d8d0e25Snbeams d_U[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D + 4577d8d0e25Snbeams comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] : 0.0; 4587d8d0e25Snbeams for (int i = Q1D; i < P1D; i++) 4597d8d0e25Snbeams r_U[i] = 0.0; 4607d8d0e25Snbeams } 4617d8d0e25Snbeams 4627d8d0e25Snbeams //------------------------------------------------------------------------------ 4637d8d0e25Snbeams // Write quadrature point data 4647d8d0e25Snbeams //------------------------------------------------------------------------------ 4657d8d0e25Snbeams inline __device__ void writeQuads3d(const int elem, const int tidx, 4667d8d0e25Snbeams const int tidy, const int comp, 4677d8d0e25Snbeams const int dim, const int nelem, 4687d8d0e25Snbeams const CeedScalar *r_V, CeedScalar *d_V) { 4697d8d0e25Snbeams if (tidx < Q1D && tidy < Q1D) { 4707d8d0e25Snbeams for (int i = 0; i < Q1D; i++) 4717d8d0e25Snbeams d_V[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem + 4727d8d0e25Snbeams dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] = r_V[i]; 4737d8d0e25Snbeams } 4747d8d0e25Snbeams } 4757d8d0e25Snbeams 4767d8d0e25Snbeams //------------------------------------------------------------------------------ 4777d8d0e25Snbeams // 3D tensor contract x 4787d8d0e25Snbeams //------------------------------------------------------------------------------ 4797d8d0e25Snbeams inline __device__ void ContractX3d(CeedScalar *slice, const int tidx, 4807d8d0e25Snbeams const int tidy, const int tidz, 4817d8d0e25Snbeams const CeedScalar *U, 4827d8d0e25Snbeams const CeedScalar *B, 4837d8d0e25Snbeams CeedScalar *V) { 4847d8d0e25Snbeams for (int k = 0; k < P1D; ++k) { 4857d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k]; 
4867d8d0e25Snbeams __syncthreads(); 4877d8d0e25Snbeams V[k] = 0.0; 4887d8d0e25Snbeams if (tidx < Q1D && tidy < P1D) 4897d8d0e25Snbeams for (int i = 0; i < P1D; ++i) 4907d8d0e25Snbeams V[k] += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction 4917d8d0e25Snbeams __syncthreads(); 4927d8d0e25Snbeams } 4937d8d0e25Snbeams } 4947d8d0e25Snbeams 4957d8d0e25Snbeams //------------------------------------------------------------------------------ 4967d8d0e25Snbeams // 3D tensor contract y 4977d8d0e25Snbeams //------------------------------------------------------------------------------ 4987d8d0e25Snbeams inline __device__ void ContractY3d(CeedScalar *slice, const int tidx, 4997d8d0e25Snbeams const int tidy, const int tidz, 5007d8d0e25Snbeams const CeedScalar *U, 5017d8d0e25Snbeams const CeedScalar *B, 5027d8d0e25Snbeams CeedScalar *V) { 5037d8d0e25Snbeams for (int k = 0; k < P1D; ++k) { 5047d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k]; 5057d8d0e25Snbeams __syncthreads(); 5067d8d0e25Snbeams V[k] = 0.0; 5077d8d0e25Snbeams if (tidx < Q1D && tidy < Q1D) 5087d8d0e25Snbeams for (int i = 0; i < P1D; ++i) 5097d8d0e25Snbeams V[k] += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction 5107d8d0e25Snbeams __syncthreads(); 5117d8d0e25Snbeams } 5127d8d0e25Snbeams } 5137d8d0e25Snbeams 5147d8d0e25Snbeams //------------------------------------------------------------------------------ 5157d8d0e25Snbeams // 3D tensor contract z 5167d8d0e25Snbeams //------------------------------------------------------------------------------ 5177d8d0e25Snbeams inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx, 5187d8d0e25Snbeams const int tidy, const int tidz, 5197d8d0e25Snbeams const CeedScalar *U, 5207d8d0e25Snbeams const CeedScalar *B, 5217d8d0e25Snbeams CeedScalar *V) { 5227d8d0e25Snbeams for (int k = 0; k < Q1D; ++k) { 5237d8d0e25Snbeams V[k] = 0.0; 5247d8d0e25Snbeams if (tidx < Q1D && tidy < Q1D) 5257d8d0e25Snbeams 
for (int i = 0; i < P1D; ++i) 5267d8d0e25Snbeams V[k] += B[i + k*P1D] * U[i]; // Contract z direction 5277d8d0e25Snbeams } 5287d8d0e25Snbeams for (int k = Q1D; k < P1D; ++k) 5297d8d0e25Snbeams V[k] = 0.0; 5307d8d0e25Snbeams } 5317d8d0e25Snbeams 5327d8d0e25Snbeams //------------------------------------------------------------------------------ 5337d8d0e25Snbeams // 3D transpose tensor contract z 5347d8d0e25Snbeams //------------------------------------------------------------------------------ 5357d8d0e25Snbeams inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx, 5367d8d0e25Snbeams const int tidy, const int tidz, 5377d8d0e25Snbeams const CeedScalar *U, 5387d8d0e25Snbeams const CeedScalar *B, 5397d8d0e25Snbeams CeedScalar *V) { 5407d8d0e25Snbeams for (int k = 0; k < P1D; ++k) { 5417d8d0e25Snbeams V[k] = 0.0; 5427d8d0e25Snbeams if (tidx < Q1D && tidy < Q1D) 5437d8d0e25Snbeams for (int i = 0; i < Q1D; ++i) 5447d8d0e25Snbeams V[k] += B[k + i*P1D] * U[i]; // Contract z direction 5457d8d0e25Snbeams } 5467d8d0e25Snbeams for (int k = P1D; k < Q1D; ++k) 5477d8d0e25Snbeams V[k] = 0.0; 5487d8d0e25Snbeams } 5497d8d0e25Snbeams 5507d8d0e25Snbeams //------------------------------------------------------------------------------ 5517d8d0e25Snbeams // 3D transpose tensor contract y 5527d8d0e25Snbeams //------------------------------------------------------------------------------ 5537d8d0e25Snbeams inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx, 5547d8d0e25Snbeams const int tidy, const int tidz, 5557d8d0e25Snbeams const CeedScalar *U, 5567d8d0e25Snbeams const CeedScalar *B, 5577d8d0e25Snbeams CeedScalar *V) { 5587d8d0e25Snbeams for (int k = 0; k < P1D; ++k) { 5597d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k]; 5607d8d0e25Snbeams __syncthreads(); 5617d8d0e25Snbeams V[k] = 0.0; 5627d8d0e25Snbeams if (tidx < Q1D && tidy < P1D) 5637d8d0e25Snbeams for (int i = 0; i < Q1D; ++i) 5647d8d0e25Snbeams V[k] += B[tidy + i*P1D] 
* slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction 5657d8d0e25Snbeams __syncthreads(); 5667d8d0e25Snbeams } 5677d8d0e25Snbeams } 5687d8d0e25Snbeams 5697d8d0e25Snbeams //------------------------------------------------------------------------------ 5707d8d0e25Snbeams // 3D transpose tensor contract x 5717d8d0e25Snbeams //------------------------------------------------------------------------------ 5727d8d0e25Snbeams inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx, 5737d8d0e25Snbeams const int tidy, const int tidz, 5747d8d0e25Snbeams const CeedScalar *U, 5757d8d0e25Snbeams const CeedScalar *B, 5767d8d0e25Snbeams CeedScalar *V) { 5777d8d0e25Snbeams for (int k = 0; k < P1D; ++k) { 5787d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k]; 5797d8d0e25Snbeams __syncthreads(); 5807d8d0e25Snbeams V[k] = 0.0; 5817d8d0e25Snbeams if (tidx < P1D && tidy < P1D) 5827d8d0e25Snbeams for (int i = 0; i < Q1D; ++i) 5837d8d0e25Snbeams V[k] += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction 5847d8d0e25Snbeams __syncthreads(); 5857d8d0e25Snbeams } 5867d8d0e25Snbeams } 5877d8d0e25Snbeams 5887d8d0e25Snbeams //------------------------------------------------------------------------------ 5897d8d0e25Snbeams // 3D interpolate to quadrature points 5907d8d0e25Snbeams //------------------------------------------------------------------------------ 5917d8d0e25Snbeams inline __device__ void interp3d(const CeedInt nelem, const int transpose, 592e9132427Snbeams const CeedScalar *s_B, 5937d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 5947d8d0e25Snbeams CeedScalar *__restrict__ d_V, 5957d8d0e25Snbeams CeedScalar *slice) { 5967d8d0e25Snbeams CeedScalar r_V[T1D]; 5977d8d0e25Snbeams CeedScalar r_t[T1D]; 5987d8d0e25Snbeams 5997d8d0e25Snbeams const int tidx = threadIdx.x; 6007d8d0e25Snbeams const int tidy = threadIdx.y; 6017d8d0e25Snbeams const int tidz = threadIdx.z; 6027d8d0e25Snbeams const int blockElem = 
tidz/BASIS_NCOMP; 6037d8d0e25Snbeams const int elemsPerBlock = blockDim.z/BASIS_NCOMP; 6047d8d0e25Snbeams const int comp = tidz%BASIS_NCOMP; 6057d8d0e25Snbeams 6067d8d0e25Snbeams for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem; 6077d8d0e25Snbeams elem += gridDim.x*elemsPerBlock) { 6087d8d0e25Snbeams for (int i = 0; i < T1D; ++i) { 6097d8d0e25Snbeams r_V[i] = 0.0; 6107d8d0e25Snbeams r_t[i] = 0.0; 6117d8d0e25Snbeams } 6127d8d0e25Snbeams if (!transpose) { 6137d8d0e25Snbeams readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V); 614e9132427Snbeams ContractX3d(slice, tidx, tidy, tidz, r_V, s_B, r_t); 615e9132427Snbeams ContractY3d(slice, tidx, tidy, tidz, r_t, s_B, r_V); 616e9132427Snbeams ContractZ3d(slice, tidx, tidy, tidz, r_V, s_B, r_t); 6177d8d0e25Snbeams writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V); 6187d8d0e25Snbeams } else { 6197d8d0e25Snbeams readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V); 620e9132427Snbeams ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, s_B, r_t); 621e9132427Snbeams ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, s_B, r_V); 622e9132427Snbeams ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, s_B, r_t); 6237d8d0e25Snbeams writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V); 6247d8d0e25Snbeams } 6257d8d0e25Snbeams } 6267d8d0e25Snbeams } 6277d8d0e25Snbeams 6287d8d0e25Snbeams //------------------------------------------------------------------------------ 6297d8d0e25Snbeams // 3D derivatives at quadrature points 6307d8d0e25Snbeams //------------------------------------------------------------------------------ 6317d8d0e25Snbeams inline __device__ void grad3d(const CeedInt nelem, const int transpose, 632e9132427Snbeams const CeedScalar *s_B, const CeedScalar *s_G, 6337d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 6347d8d0e25Snbeams CeedScalar *__restrict__ d_V, 6357d8d0e25Snbeams CeedScalar *slice) { 6367d8d0e25Snbeams // Use P1D for one of these 6377d8d0e25Snbeams CeedScalar r_U[T1D]; 
6387d8d0e25Snbeams CeedScalar r_V[T1D]; 6397d8d0e25Snbeams CeedScalar r_t[T1D]; 6407d8d0e25Snbeams 6417d8d0e25Snbeams const int tidx = threadIdx.x; 6427d8d0e25Snbeams const int tidy = threadIdx.y; 6437d8d0e25Snbeams const int tidz = threadIdx.z; 6447d8d0e25Snbeams const int blockElem = tidz/BASIS_NCOMP; 6457d8d0e25Snbeams const int elemsPerBlock = blockDim.z/BASIS_NCOMP; 6467d8d0e25Snbeams const int comp = tidz%BASIS_NCOMP; 6477d8d0e25Snbeams int dim; 6487d8d0e25Snbeams 6497d8d0e25Snbeams for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem; 6507d8d0e25Snbeams elem += gridDim.x*elemsPerBlock) { 6517d8d0e25Snbeams for (int i = 0; i < T1D; ++i) { 6527d8d0e25Snbeams r_U[i] = 0.0; 6537d8d0e25Snbeams r_V[i] = 0.0; 6547d8d0e25Snbeams r_t[i] = 0.0; 6557d8d0e25Snbeams } 6567d8d0e25Snbeams if (!transpose) { 6577d8d0e25Snbeams readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U); 658e9132427Snbeams ContractX3d(slice, tidx, tidy, tidz, r_U, s_G, r_V); 659e9132427Snbeams ContractY3d(slice, tidx, tidy, tidz, r_V, s_B, r_t); 660e9132427Snbeams ContractZ3d(slice, tidx, tidy, tidz, r_t, s_B, r_V); 6617d8d0e25Snbeams dim = 0; 6627d8d0e25Snbeams writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V); 663e9132427Snbeams ContractX3d(slice, tidx, tidy, tidz, r_U, s_B, r_V); 664e9132427Snbeams ContractY3d(slice, tidx, tidy, tidz, r_V, s_G, r_t); 665e9132427Snbeams ContractZ3d(slice, tidx, tidy, tidz, r_t, s_B, r_V); 6667d8d0e25Snbeams dim = 1; 6677d8d0e25Snbeams writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V); 668e9132427Snbeams ContractX3d(slice, tidx, tidy, tidz, r_U, s_B, r_V); 669e9132427Snbeams ContractY3d(slice, tidx, tidy, tidz, r_V, s_B, r_t); 670e9132427Snbeams ContractZ3d(slice, tidx, tidy, tidz, r_t, s_G, r_V); 6717d8d0e25Snbeams dim = 2; 6727d8d0e25Snbeams writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V); 6737d8d0e25Snbeams } else { 6747d8d0e25Snbeams dim = 0; 6757d8d0e25Snbeams readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, 
r_U); 676e9132427Snbeams ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, s_B, r_t); 677e9132427Snbeams ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, s_B, r_U); 678e9132427Snbeams ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, s_G, r_V); 6797d8d0e25Snbeams dim = 1; 6807d8d0e25Snbeams readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U); 681e9132427Snbeams ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, s_B, r_t); 682e9132427Snbeams ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, s_G, r_U); 683e9132427Snbeams ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, s_B, r_t); 6847d8d0e25Snbeams add(r_V, r_t); 6857d8d0e25Snbeams dim = 2; 6867d8d0e25Snbeams readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U); 687e9132427Snbeams ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, s_G, r_t); 688e9132427Snbeams ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, s_B, r_U); 689e9132427Snbeams ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, s_B, r_t); 6907d8d0e25Snbeams add(r_V, r_t); 6917d8d0e25Snbeams writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V); 6927d8d0e25Snbeams } 6937d8d0e25Snbeams } 6947d8d0e25Snbeams } 6957d8d0e25Snbeams 6967d8d0e25Snbeams //------------------------------------------------------------------------------ 6977d8d0e25Snbeams // 3D quadrature weights 6987d8d0e25Snbeams //------------------------------------------------------------------------------ 6997d8d0e25Snbeams __device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d, 7007d8d0e25Snbeams CeedScalar *w) { 7017d8d0e25Snbeams const int i = threadIdx.x; 7027d8d0e25Snbeams const int j = threadIdx.y; 7037d8d0e25Snbeams const int k = threadIdx.z; 7047d8d0e25Snbeams const CeedScalar weight = qweight1d[i]*qweight1d[j]*qweight1d[k]; 7057d8d0e25Snbeams for (int e = blockIdx.x; e < nelem; e += gridDim.x) { 7067d8d0e25Snbeams const int ind = e*Q1D*Q1D*Q1D + i + j*Q1D + k*Q1D*Q1D; 7077d8d0e25Snbeams w[ind] = weight; 7087d8d0e25Snbeams } 7097d8d0e25Snbeams } 
7107d8d0e25Snbeams 7117d8d0e25Snbeams 7127d8d0e25Snbeams //------------------------------------------------------------------------------ 7137d8d0e25Snbeams // Basis kernels 7147d8d0e25Snbeams //------------------------------------------------------------------------------ 7157d8d0e25Snbeams 7167d8d0e25Snbeams //------------------------------------------------------------------------------ 7177d8d0e25Snbeams // Interp kernel by dim 7187d8d0e25Snbeams //------------------------------------------------------------------------------ 7199e31c45bSnbeams extern "C" __launch_bounds__(INTERP_BLKSIZE) __global__ void interp( 7209e31c45bSnbeams const CeedInt nelem, const int transpose, 7219dd88646Snbeams CeedScalar *d_interp1d, 7227d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 7237d8d0e25Snbeams CeedScalar *__restrict__ d_V) { 7249dd88646Snbeams 7257d8d0e25Snbeams HIP_DYNAMIC_SHARED( double, slice) 7269dd88646Snbeams // load interp1d into shared memory 727e9132427Snbeams __shared__ double s_B[P1D*Q1D]; 728e9132427Snbeams loadMatrix(d_interp1d, s_B); 72955250509Snbeams __syncthreads(); 7309dd88646Snbeams 7317d8d0e25Snbeams if (BASIS_DIM == 1) { 732e9132427Snbeams interp1d(nelem, transpose, s_B, d_U, d_V, slice); 7337d8d0e25Snbeams } else if (BASIS_DIM == 2) { 734e9132427Snbeams interp2d(nelem, transpose, s_B, d_U, d_V, slice); 7357d8d0e25Snbeams } else if (BASIS_DIM == 3) { 736e9132427Snbeams interp3d(nelem, transpose, s_B, d_U, d_V, slice); 7377d8d0e25Snbeams } 7387d8d0e25Snbeams } 7397d8d0e25Snbeams 7407d8d0e25Snbeams //------------------------------------------------------------------------------ 7417d8d0e25Snbeams // Grad kernel by dim 7427d8d0e25Snbeams //------------------------------------------------------------------------------ 7439e31c45bSnbeams extern "C" __launch_bounds__(GRAD_BLKSIZE) __global__ void grad(const CeedInt nelem, 7449e31c45bSnbeams const int transpose, 7459dd88646Snbeams CeedScalar *d_interp1d, CeedScalar *d_grad1d, 7467d8d0e25Snbeams const 
CeedScalar *__restrict__ d_U, 7477d8d0e25Snbeams CeedScalar *__restrict__ d_V) { 7487d8d0e25Snbeams HIP_DYNAMIC_SHARED( double, slice) 7499dd88646Snbeams // load interp1d and grad1d into shared memory 750e9132427Snbeams __shared__ double s_B[P1D*Q1D]; 751e9132427Snbeams loadMatrix(d_interp1d, s_B); 752e9132427Snbeams __shared__ double s_G[P1D*Q1D]; 753e9132427Snbeams loadMatrix(d_grad1d, s_G); 75455250509Snbeams __syncthreads(); 7559dd88646Snbeams 7567d8d0e25Snbeams if (BASIS_DIM == 1) { 757e9132427Snbeams grad1d(nelem, transpose, s_B, s_G, d_U, d_V, slice); 7587d8d0e25Snbeams } else if (BASIS_DIM == 2) { 759e9132427Snbeams grad2d(nelem, transpose, s_B, s_G, d_U, d_V, slice); 7607d8d0e25Snbeams } else if (BASIS_DIM == 3) { 761e9132427Snbeams grad3d(nelem, transpose, s_B, s_G, d_U, d_V, slice); 7627d8d0e25Snbeams } 7637d8d0e25Snbeams } 7647d8d0e25Snbeams 7657d8d0e25Snbeams //------------------------------------------------------------------------------ 7667d8d0e25Snbeams // Weight kernels by dim 7677d8d0e25Snbeams //------------------------------------------------------------------------------ 7689e31c45bSnbeams extern "C" __launch_bounds__(WEIGHT_BLKSIZE) __global__ void weight(const CeedInt nelem, 7697d8d0e25Snbeams const CeedScalar *__restrict__ qweight1d, 7707d8d0e25Snbeams CeedScalar *__restrict__ v) { 7717d8d0e25Snbeams if (BASIS_DIM == 1) { 7727d8d0e25Snbeams weight1d(nelem, qweight1d, v); 7737d8d0e25Snbeams } else if (BASIS_DIM == 2) { 7747d8d0e25Snbeams weight2d(nelem, qweight1d, v); 7757d8d0e25Snbeams } else if (BASIS_DIM == 3) { 7767d8d0e25Snbeams weight3d(nelem, qweight1d, v); 7777d8d0e25Snbeams } 7787d8d0e25Snbeams } 7797d8d0e25Snbeams 7807d8d0e25Snbeams ); 7817d8d0e25Snbeams // *INDENT-ON* 7827d8d0e25Snbeams 7837d8d0e25Snbeams //------------------------------------------------------------------------------ 7849e31c45bSnbeams // Compute a block size based on required minimum threads 7859e31c45bSnbeams 
//------------------------------------------------------------------------------ 7869e31c45bSnbeams static CeedInt ComputeBlockSizeFromRequirement(const CeedInt required) { 7879e31c45bSnbeams CeedInt maxSize = 1024; // Max total threads per block 7889e31c45bSnbeams CeedInt currentSize = 64; // Start with one group 7899e31c45bSnbeams 7909e31c45bSnbeams while(currentSize < maxSize) { 7919e31c45bSnbeams if (currentSize > required) 7929e31c45bSnbeams break; 7939e31c45bSnbeams else 7949e31c45bSnbeams currentSize = currentSize * 2; 7959e31c45bSnbeams } 7969e31c45bSnbeams return currentSize; 7979e31c45bSnbeams } 7989e31c45bSnbeams 7999e31c45bSnbeams //------------------------------------------------------------------------------ 8009e31c45bSnbeams // Compute required thread block sizes for basis kernels given P, Q, dim, and 8019e31c45bSnbeams // ncomp 8029e31c45bSnbeams //------------------------------------------------------------------------------ 8039e31c45bSnbeams static int ComputeBasisThreadBlockSizes(const CeedInt dim, const CeedInt P1d, 8049e31c45bSnbeams const CeedInt Q1d, 8059e31c45bSnbeams const CeedInt ncomp, CeedInt *blksizes) { 8069e31c45bSnbeams 8079e31c45bSnbeams // Note that this will use the same block sizes for all dimensions when compiling, 8089e31c45bSnbeams // but as each basis object is defined for a particular dimension, we will never 8099e31c45bSnbeams // call any kernels except the ones for the dimension for which we have computed the 8109e31c45bSnbeams // block sizes. 
8119e31c45bSnbeams const CeedInt thread1d = CeedIntMax(P1d, Q1d); 8129e31c45bSnbeams switch (dim) { 8139e31c45bSnbeams case 1: { 8149e31c45bSnbeams // Interp kernels: 8159e31c45bSnbeams blksizes[0] = 256; 8169e31c45bSnbeams 8179e31c45bSnbeams // Grad kernels: 8189e31c45bSnbeams blksizes[1] = 256; 8199e31c45bSnbeams 8209e31c45bSnbeams // Weight kernels: 8219e31c45bSnbeams blksizes[2] = 256; 8229e31c45bSnbeams 8239e31c45bSnbeams } break; 8249e31c45bSnbeams case 2: { 8259e31c45bSnbeams // Interp kernels: 8269e31c45bSnbeams CeedInt required = thread1d * thread1d * ncomp; 8279e31c45bSnbeams blksizes[0] = ComputeBlockSizeFromRequirement(required); 8289e31c45bSnbeams 8299e31c45bSnbeams // Grad kernels: currently use same required minimum threads 8309e31c45bSnbeams blksizes[1] = ComputeBlockSizeFromRequirement(required); 8319e31c45bSnbeams 8329e31c45bSnbeams // Weight kernels: 8339e31c45bSnbeams required = CeedIntMax(64, Q1d * Q1d); 8349e31c45bSnbeams blksizes[2] = ComputeBlockSizeFromRequirement(required); 8359e31c45bSnbeams 8369e31c45bSnbeams } break; 8379e31c45bSnbeams case 3: { 8389e31c45bSnbeams // Interp kernels: 8399e31c45bSnbeams CeedInt required = thread1d * thread1d * ncomp; 8409e31c45bSnbeams blksizes[0] = ComputeBlockSizeFromRequirement(required); 8419e31c45bSnbeams 8429e31c45bSnbeams // Grad kernels: currently use same required minimum threads 8439e31c45bSnbeams blksizes[1] = ComputeBlockSizeFromRequirement(required); 8449e31c45bSnbeams 8459e31c45bSnbeams // Weight kernels: 8469e31c45bSnbeams required = Q1d * Q1d * Q1d; 8479e31c45bSnbeams blksizes[2] = ComputeBlockSizeFromRequirement(required); 8489e31c45bSnbeams } 8499e31c45bSnbeams } 8509e31c45bSnbeams 851*e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 8529e31c45bSnbeams } 8539e31c45bSnbeams 8549e31c45bSnbeams //------------------------------------------------------------------------------ 8557d8d0e25Snbeams // Apply basis 8567d8d0e25Snbeams 
//------------------------------------------------------------------------------ 8577d8d0e25Snbeams int CeedBasisApplyTensor_Hip_shared(CeedBasis basis, const CeedInt nelem, 8587d8d0e25Snbeams CeedTransposeMode tmode, 8597d8d0e25Snbeams CeedEvalMode emode, CeedVector u, 8607d8d0e25Snbeams CeedVector v) { 8617d8d0e25Snbeams int ierr; 8627d8d0e25Snbeams Ceed ceed; 863*e15f9bd0SJeremy L Thompson ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr); 8647d8d0e25Snbeams Ceed_Hip_shared *ceed_Hip; 865*e15f9bd0SJeremy L Thompson CeedGetData(ceed, &ceed_Hip); CeedChkBackend(ierr); 8667d8d0e25Snbeams CeedBasis_Hip_shared *data; 867*e15f9bd0SJeremy L Thompson CeedBasisGetData(basis, &data); CeedChkBackend(ierr); 8687d8d0e25Snbeams const CeedInt transpose = tmode == CEED_TRANSPOSE; 8697d8d0e25Snbeams CeedInt dim, ncomp; 870*e15f9bd0SJeremy L Thompson ierr = CeedBasisGetDimension(basis, &dim); CeedChkBackend(ierr); 871*e15f9bd0SJeremy L Thompson ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChkBackend(ierr); 8727d8d0e25Snbeams 8737d8d0e25Snbeams // Read vectors 8747d8d0e25Snbeams const CeedScalar *d_u; 8757d8d0e25Snbeams CeedScalar *d_v; 8767d8d0e25Snbeams if (emode != CEED_EVAL_WEIGHT) { 877*e15f9bd0SJeremy L Thompson ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChkBackend(ierr); 8787d8d0e25Snbeams } 879*e15f9bd0SJeremy L Thompson ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChkBackend(ierr); 8807d8d0e25Snbeams 8817d8d0e25Snbeams // Clear v for transpose mode 8827d8d0e25Snbeams if (tmode == CEED_TRANSPOSE) { 8837d8d0e25Snbeams CeedInt length; 884*e15f9bd0SJeremy L Thompson ierr = CeedVectorGetLength(v, &length); CeedChkBackend(ierr); 885*e15f9bd0SJeremy L Thompson ierr = hipMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChkBackend(ierr); 8867d8d0e25Snbeams } 8877d8d0e25Snbeams 8887d8d0e25Snbeams // Apply basis operation 8897d8d0e25Snbeams switch (emode) { 8907d8d0e25Snbeams case CEED_EVAL_INTERP: { 8917d8d0e25Snbeams CeedInt P1d, 
Q1d; 8929e31c45bSnbeams CeedInt blksize = data->blksizes[0]; 893*e15f9bd0SJeremy L Thompson ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChkBackend(ierr); 894*e15f9bd0SJeremy L Thompson ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr); 8957d8d0e25Snbeams CeedInt thread1d = CeedIntMax(Q1d, P1d); 8969dd88646Snbeams void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->d_interp1d, 8977d8d0e25Snbeams &d_u, &d_v 8987d8d0e25Snbeams }; 8997d8d0e25Snbeams if (dim == 1) { 900e7ea6884Snbeams CeedInt elemsPerBlock = 64*thread1d > 256? 256/thread1d : 64; 9017d8d0e25Snbeams elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1; 9027d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 9037d8d0e25Snbeams ? 1 : 0 ); 9047d8d0e25Snbeams CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar); 9057d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, 1, 9067d8d0e25Snbeams elemsPerBlock, sharedMem, 907*e15f9bd0SJeremy L Thompson interpargs); CeedChkBackend(ierr); 9087d8d0e25Snbeams } else if (dim == 2) { 9099e31c45bSnbeams // Check if required threads is small enough to do multiple elems 9109e31c45bSnbeams const CeedInt elemsPerBlock = CeedIntMax(blksize/(thread1d*thread1d*ncomp), 1); 9117d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 9127d8d0e25Snbeams ? 1 : 0 ); 9137d8d0e25Snbeams CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar); 9147d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d, 9157d8d0e25Snbeams ncomp*elemsPerBlock, sharedMem, 916*e15f9bd0SJeremy L Thompson interpargs); CeedChkBackend(ierr); 9177d8d0e25Snbeams } else if (dim == 3) { 9187d8d0e25Snbeams CeedInt elemsPerBlock = 1; 9197d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 9207d8d0e25Snbeams ? 
1 : 0 ); 9217d8d0e25Snbeams CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar); 9227d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d, 9237d8d0e25Snbeams ncomp*elemsPerBlock, sharedMem, 924*e15f9bd0SJeremy L Thompson interpargs); CeedChkBackend(ierr); 9257d8d0e25Snbeams } 9267d8d0e25Snbeams } break; 9277d8d0e25Snbeams case CEED_EVAL_GRAD: { 9287d8d0e25Snbeams CeedInt P1d, Q1d; 9299e31c45bSnbeams CeedInt blksize = data->blksizes[1]; 930*e15f9bd0SJeremy L Thompson ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChkBackend(ierr); 931*e15f9bd0SJeremy L Thompson ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr); 9327d8d0e25Snbeams CeedInt thread1d = CeedIntMax(Q1d, P1d); 9339dd88646Snbeams void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->d_interp1d, 9349dd88646Snbeams &data->d_grad1d, &d_u, &d_v 9357d8d0e25Snbeams }; 9367d8d0e25Snbeams if (dim == 1) { 937e7ea6884Snbeams CeedInt elemsPerBlock = 64*thread1d > 256? 256/thread1d : 64; 9387d8d0e25Snbeams elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1; 9397d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 9407d8d0e25Snbeams ? 1 : 0 ); 9417d8d0e25Snbeams CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar); 9427d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, 1, 9437d8d0e25Snbeams elemsPerBlock, sharedMem, gradargs); 944*e15f9bd0SJeremy L Thompson CeedChkBackend(ierr); 9457d8d0e25Snbeams } else if (dim == 2) { 9469e31c45bSnbeams // Check if required threads is small enough to do multiple elems 9479e31c45bSnbeams const CeedInt elemsPerBlock = CeedIntMax(blksize/(thread1d*thread1d*ncomp), 1); 9487d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 9497d8d0e25Snbeams ? 
1 : 0 ); 9507d8d0e25Snbeams CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar); 9517d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d, 9527d8d0e25Snbeams ncomp*elemsPerBlock, sharedMem, 953*e15f9bd0SJeremy L Thompson gradargs); CeedChkBackend(ierr); 9547d8d0e25Snbeams } else if (dim == 3) { 9557d8d0e25Snbeams CeedInt elemsPerBlock = 1; 9567d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 9577d8d0e25Snbeams ? 1 : 0 ); 9587d8d0e25Snbeams CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar); 9597d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d, 9607d8d0e25Snbeams ncomp*elemsPerBlock, sharedMem, 961*e15f9bd0SJeremy L Thompson gradargs); CeedChkBackend(ierr); 9627d8d0e25Snbeams } 9637d8d0e25Snbeams } break; 9647d8d0e25Snbeams case CEED_EVAL_WEIGHT: { 9657d8d0e25Snbeams CeedInt Q1d; 9669e31c45bSnbeams CeedInt blksize = data->blksizes[2]; 967*e15f9bd0SJeremy L Thompson ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr); 9687d8d0e25Snbeams void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v}; 9697d8d0e25Snbeams if (dim == 1) { 9709e31c45bSnbeams const CeedInt optElems = blksize/Q1d; 9717d8d0e25Snbeams const CeedInt elemsPerBlock = optElems>0?optElems:1; 9727d8d0e25Snbeams const CeedInt gridsize = nelem/elemsPerBlock + ( ( 9737d8d0e25Snbeams nelem/elemsPerBlock*elemsPerBlock<nelem)? 
1 : 0 ); 9747d8d0e25Snbeams ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, 9757d8d0e25Snbeams elemsPerBlock, 1, weightargs); 976*e15f9bd0SJeremy L Thompson CeedChkBackend(ierr); 9777d8d0e25Snbeams } else if (dim == 2) { 9789e31c45bSnbeams const CeedInt optElems = blksize/(Q1d*Q1d); 9797d8d0e25Snbeams const CeedInt elemsPerBlock = optElems>0?optElems:1; 9807d8d0e25Snbeams const CeedInt gridsize = nelem/elemsPerBlock + ( ( 9817d8d0e25Snbeams nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 ); 9827d8d0e25Snbeams ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d, 9837d8d0e25Snbeams elemsPerBlock, weightargs); 984*e15f9bd0SJeremy L Thompson CeedChkBackend(ierr); 9857d8d0e25Snbeams } else if (dim == 3) { 9867d8d0e25Snbeams const CeedInt gridsize = nelem; 9877d8d0e25Snbeams ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d, Q1d, 9887d8d0e25Snbeams weightargs); 989*e15f9bd0SJeremy L Thompson CeedChkBackend(ierr); 9907d8d0e25Snbeams } 9917d8d0e25Snbeams } break; 9927d8d0e25Snbeams // LCOV_EXCL_START 9937d8d0e25Snbeams // Evaluate the divergence to/from the quadrature points 9947d8d0e25Snbeams case CEED_EVAL_DIV: 995*e15f9bd0SJeremy L Thompson return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported"); 9967d8d0e25Snbeams // Evaluate the curl to/from the quadrature points 9977d8d0e25Snbeams case CEED_EVAL_CURL: 998*e15f9bd0SJeremy L Thompson return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported"); 9997d8d0e25Snbeams // Take no action, BasisApply should not have been called 10007d8d0e25Snbeams case CEED_EVAL_NONE: 1001*e15f9bd0SJeremy L Thompson return CeedError(ceed, CEED_ERROR_BACKEND, 10027d8d0e25Snbeams "CEED_EVAL_NONE does not make sense in this context"); 10037d8d0e25Snbeams // LCOV_EXCL_STOP 10047d8d0e25Snbeams } 10057d8d0e25Snbeams 10067d8d0e25Snbeams // Restore vectors 10077d8d0e25Snbeams if (emode != CEED_EVAL_WEIGHT) { 1008*e15f9bd0SJeremy L Thompson ierr = CeedVectorRestoreArrayRead(u, 
&d_u); CeedChkBackend(ierr); 10097d8d0e25Snbeams } 1010*e15f9bd0SJeremy L Thompson ierr = CeedVectorRestoreArray(v, &d_v); CeedChkBackend(ierr); 1011*e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 10127d8d0e25Snbeams } 10137d8d0e25Snbeams 10147d8d0e25Snbeams //------------------------------------------------------------------------------ 10157d8d0e25Snbeams // Destroy basis 10167d8d0e25Snbeams //------------------------------------------------------------------------------ 10177d8d0e25Snbeams static int CeedBasisDestroy_Hip_shared(CeedBasis basis) { 10187d8d0e25Snbeams int ierr; 10197d8d0e25Snbeams Ceed ceed; 1020*e15f9bd0SJeremy L Thompson ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr); 10217d8d0e25Snbeams 10227d8d0e25Snbeams CeedBasis_Hip_shared *data; 1023*e15f9bd0SJeremy L Thompson ierr = CeedBasisGetData(basis, &data); CeedChkBackend(ierr); 10247d8d0e25Snbeams 10257d8d0e25Snbeams CeedChk_Hip(ceed, hipModuleUnload(data->module)); 10267d8d0e25Snbeams 10277d8d0e25Snbeams ierr = hipFree(data->d_qweight1d); CeedChk_Hip(ceed, ierr); 10287d8d0e25Snbeams ierr = hipFree(data->d_interp1d); CeedChk_Hip(ceed, ierr); 10297d8d0e25Snbeams ierr = hipFree(data->d_grad1d); CeedChk_Hip(ceed, ierr); 10307d8d0e25Snbeams ierr = hipFree(data->d_collograd1d); CeedChk_Hip(ceed, ierr); 10317d8d0e25Snbeams 1032*e15f9bd0SJeremy L Thompson ierr = CeedFree(&data); CeedChkBackend(ierr); 10337d8d0e25Snbeams 1034*e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 10357d8d0e25Snbeams } 10367d8d0e25Snbeams 10377d8d0e25Snbeams //------------------------------------------------------------------------------ 10387d8d0e25Snbeams // Create tensor basis 10397d8d0e25Snbeams //------------------------------------------------------------------------------ 10407d8d0e25Snbeams int CeedBasisCreateTensorH1_Hip_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d, 10417d8d0e25Snbeams const CeedScalar *interp1d, 10427d8d0e25Snbeams const CeedScalar *grad1d, 10437d8d0e25Snbeams const CeedScalar 
*qref1d, 10447d8d0e25Snbeams const CeedScalar *qweight1d, 10457d8d0e25Snbeams CeedBasis basis) { 10467d8d0e25Snbeams int ierr; 10477d8d0e25Snbeams Ceed ceed; 1048*e15f9bd0SJeremy L Thompson ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr); 10497d8d0e25Snbeams CeedBasis_Hip_shared *data; 1050*e15f9bd0SJeremy L Thompson ierr = CeedCalloc(1, &data); CeedChkBackend(ierr); 10517d8d0e25Snbeams 10527d8d0e25Snbeams // Copy basis data to GPU 10537d8d0e25Snbeams const CeedInt qBytes = Q1d * sizeof(CeedScalar); 10547d8d0e25Snbeams ierr = hipMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Hip(ceed, ierr); 10557d8d0e25Snbeams ierr = hipMemcpy(data->d_qweight1d, qweight1d, qBytes, 10567d8d0e25Snbeams hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr); 10577d8d0e25Snbeams 10587d8d0e25Snbeams const CeedInt iBytes = qBytes * P1d; 10597d8d0e25Snbeams ierr = hipMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Hip(ceed, ierr); 10607d8d0e25Snbeams ierr = hipMemcpy(data->d_interp1d, interp1d, iBytes, 10617d8d0e25Snbeams hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr); 10627d8d0e25Snbeams 10637d8d0e25Snbeams ierr = hipMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Hip(ceed, ierr); 10647d8d0e25Snbeams ierr = hipMemcpy(data->d_grad1d, grad1d, iBytes, 10657d8d0e25Snbeams hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr); 10667d8d0e25Snbeams 10677d8d0e25Snbeams // Compute collocated gradient and copy to GPU 10687d8d0e25Snbeams data->d_collograd1d = NULL; 10697d8d0e25Snbeams if (dim == 3 && Q1d >= P1d) { 10707d8d0e25Snbeams CeedScalar *collograd1d; 1071*e15f9bd0SJeremy L Thompson ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChkBackend(ierr); 1072*e15f9bd0SJeremy L Thompson ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChkBackend(ierr); 10737d8d0e25Snbeams ierr = hipMalloc((void **)&data->d_collograd1d, qBytes * Q1d); 10747d8d0e25Snbeams CeedChk_Hip(ceed, ierr); 10757d8d0e25Snbeams ierr = hipMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d, 
10767d8d0e25Snbeams hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr); 1077*e15f9bd0SJeremy L Thompson ierr = CeedFree(&collograd1d); CeedChkBackend(ierr); 10787d8d0e25Snbeams } 10797d8d0e25Snbeams 10809e31c45bSnbeams // Set number of threads per block for basis kernels 10817d8d0e25Snbeams CeedInt ncomp; 1082*e15f9bd0SJeremy L Thompson ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChkBackend(ierr); 10839e31c45bSnbeams ierr = ComputeBasisThreadBlockSizes(dim, P1d, Q1d, ncomp, data->blksizes); 1084*e15f9bd0SJeremy L Thompson CeedChkBackend(ierr); 10859e31c45bSnbeams 10869e31c45bSnbeams // Compile basis kernels 10879e31c45bSnbeams ierr = CeedCompileHip(ceed, kernelsShared, &data->module, 11, 10887d8d0e25Snbeams "Q1D", Q1d, 10897d8d0e25Snbeams "P1D", P1d, 10907d8d0e25Snbeams "T1D", CeedIntMax(Q1d, P1d), 10917d8d0e25Snbeams "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ? 10927d8d0e25Snbeams Q1d : P1d, dim), 10937d8d0e25Snbeams "BASIS_DIM", dim, 10947d8d0e25Snbeams "BASIS_NCOMP", ncomp, 10957d8d0e25Snbeams "BASIS_ELEMSIZE", CeedIntPow(P1d, dim), 10969e31c45bSnbeams "BASIS_NQPT", CeedIntPow(Q1d, dim), 10979e31c45bSnbeams "INTERP_BLKSIZE", data->blksizes[0], 10989e31c45bSnbeams "GRAD_BLKSIZE", data->blksizes[1], 10999e31c45bSnbeams "WEIGHT_BLKSIZE", data->blksizes[2] 1100*e15f9bd0SJeremy L Thompson ); CeedChkBackend(ierr); 11017d8d0e25Snbeams ierr = CeedGetKernelHip(ceed, data->module, "interp", &data->interp); 1102*e15f9bd0SJeremy L Thompson CeedChkBackend(ierr); 11037d8d0e25Snbeams ierr = CeedGetKernelHip(ceed, data->module, "grad", &data->grad); 1104*e15f9bd0SJeremy L Thompson CeedChkBackend(ierr); 11057d8d0e25Snbeams ierr = CeedGetKernelHip(ceed, data->module, "weight", &data->weight); 1106*e15f9bd0SJeremy L Thompson CeedChkBackend(ierr); 11077d8d0e25Snbeams 1108*e15f9bd0SJeremy L Thompson ierr = CeedBasisSetData(basis, data); CeedChkBackend(ierr); 11097d8d0e25Snbeams 11107d8d0e25Snbeams // Register backend functions 11117d8d0e25Snbeams ierr = 
CeedSetBackendFunction(ceed, "Basis", basis, "Apply", 11127d8d0e25Snbeams CeedBasisApplyTensor_Hip_shared); 1113*e15f9bd0SJeremy L Thompson CeedChkBackend(ierr); 11147d8d0e25Snbeams ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", 1115*e15f9bd0SJeremy L Thompson CeedBasisDestroy_Hip_shared); CeedChkBackend(ierr); 1116*e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 11177d8d0e25Snbeams } 11187d8d0e25Snbeams //------------------------------------------------------------------------------ 1119