17d8d0e25Snbeams // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC. 27d8d0e25Snbeams // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707. 37d8d0e25Snbeams // All Rights reserved. See files LICENSE and NOTICE for details. 47d8d0e25Snbeams // 57d8d0e25Snbeams // This file is part of CEED, a collection of benchmarks, miniapps, software 67d8d0e25Snbeams // libraries and APIs for efficient high-order finite element and spectral 77d8d0e25Snbeams // element discretizations for exascale applications. For more information and 87d8d0e25Snbeams // source code availability see http://github.com/ceed. 97d8d0e25Snbeams // 107d8d0e25Snbeams // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC, 117d8d0e25Snbeams // a collaborative effort of two U.S. Department of Energy organizations (Office 127d8d0e25Snbeams // of Science and the National Nuclear Security Administration) responsible for 137d8d0e25Snbeams // the planning and preparation of a capable exascale ecosystem, including 147d8d0e25Snbeams // software, applications, hardware, advanced system engineering and early 157d8d0e25Snbeams // testbed platforms, in support of the nation's exascale computing imperative. 
#include <ceed/ceed.h>
#include <ceed/backend.h>
#include <hip/hip_runtime.h>
#include <stddef.h>
#include "ceed-hip-shared.h"
#include "../hip/ceed-hip.h"
#include "../hip/ceed-hip-compile.h"

//------------------------------------------------------------------------------
// Shared mem kernels
//------------------------------------------------------------------------------
// *INDENT-OFF*
static const char *kernelsShared = QUOTE(

//------------------------------------------------------------------------------
// Sum input into output
//------------------------------------------------------------------------------
// Entry-wise accumulation over the first P1D entries: r_V[k] += r_U[k].
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  for (int k = 0; k < P1D; ++k) {
    r_V[k] += r_U[k];
  }
}

//------------------------------------------------------------------------------
// Load matrices for basis actions
//------------------------------------------------------------------------------
// Cooperative copy of the P1D*Q1D entries of d_B into B. Each thread of the
// block starts at its flattened (x, y, z) thread index and strides by the total
// number of threads in the block. No barrier is issued here; presumably callers
// synchronize before B is read — confirm at the call sites.
inline __device__ void loadMatrix(const CeedScalar* d_B, CeedScalar* B) {
  const CeedInt blockThreads = blockDim.x*blockDim.y*blockDim.z;
  const CeedInt flatTid = threadIdx.x + blockDim.x*(threadIdx.y + blockDim.y*threadIdx.z);
  for (CeedInt idx = flatTid; idx < P1D*Q1D; idx += blockThreads)
    B[idx] = d_B[idx];
}

//------------------------------------------------------------------------------
// 1D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// Read DoFs
//------------------------------------------------------------------------------
// Copies the P1D nodal values of component `comp` of element `elem` into this
// thread's tidz row of `slice` (slice[tidz*T1D + i]) and zero-fills entries
// P1D..Q1D-1. d_U is indexed as [comp][elem][node]. Every thread of the row
// writes the same values; tidx and tidy are unused.
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int tidz,const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *slice) {
  for (int i = 0; i < P1D; i++)
    slice[i + tidz*T1D] = d_U[i + elem*P1D + comp*P1D*nelem];
  for (int i = P1D; i < Q1D; i++)
    slice[i + tidz*T1D] = 0.0;
}

//------------------------------------------------------------------------------
// Write DoFs
//------------------------------------------------------------------------------
// Stores r_V as node `tidx` of component `comp` of element `elem` in d_V
// ([comp][elem][node] layout). Only threads with tidx < P1D write.
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx<P1D)
    d_V[tidx + elem*P1D + comp*P1D*nelem] = r_V;
}

//------------------------------------------------------------------------------
// Read quadrature point data
//------------------------------------------------------------------------------
// Copies the Q1D quadrature values of (dim, comp, elem) into this thread's tidz
// row of `slice` and zero-fills entries Q1D..P1D-1. d_U is indexed as
// [dim][comp][elem][qpt].
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int tidz, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *slice) {
  for (int i = 0; i < Q1D; i++)
    slice[i + tidz*T1D] = d_U[i + elem*Q1D + comp*Q1D*nelem +
                              dim*BASIS_NCOMP*nelem*Q1D];
  for (int i = Q1D; i < P1D; i++)
    slice[i + tidz*T1D] = 0.0;
}

//------------------------------------------------------------------------------
// Write quadrature point data
//------------------------------------------------------------------------------
// Stores r_V as quadrature point `tidx` of (dim, comp, elem) in d_V
// ([dim][comp][elem][qpt] layout). Only threads with tidx < Q1D write.
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx<Q1D)
    d_V[tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D] = r_V;
}

//------------------------------------------------------------------------------
// 1D tensor contraction
//------------------------------------------------------------------------------
// V = sum_i B[i + tidx*P1D] * slice[tidz row][i], i.e. row `tidx` of the
// Q1D x P1D matrix B applied to the nodal values in slice. U is unused; the
// input comes from slice. Only threads with tidx < Q1D own a row of B (which
// holds P1D*Q1D entries); the guard prevents reading past B when P1D > Q1D,
// matching ContractX2d. Other threads return V = 0.
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  V = 0.0;
  if (tidx < Q1D)
    for (int i = 0; i < P1D; ++i)
      V += B[i + tidx*P1D] * slice[i + tidz*T1D]; // Contract x direction
}

//------------------------------------------------------------------------------
// 1D transpose tensor contraction
//------------------------------------------------------------------------------
// V = sum_i B[tidx + i*P1D] * slice[tidz row][i], i.e. column `tidx` of B
// applied to the quadrature values in slice. U is unused. Only threads with
// tidx < P1D own a column of B; the guard prevents reading past B when
// Q1D > P1D, matching ContractTransposeX2d. Other threads return V = 0.
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  V = 0.0;
  if (tidx < P1D)
    for (int i = 0; i < Q1D; ++i)
      V += B[tidx + i*P1D] * slice[i + tidz*T1D]; // Contract x direction
}

//------------------------------------------------------------------------------
// 1D interpolate to quadrature points
//------------------------------------------------------------------------------
// Applies s_B (transpose == 0: nodes -> qpts) or its transpose
// (transpose != 0: qpts -> nodes) to every component of every element.
// Elements are distributed one per threadIdx.z, striding by
// gridDim.x*blockDim.z; tidx selects the output node/qpt.
//
// NOTE(review): the 1D path issues no barrier between writing `slice` in
// readDofs1d/readQuads1d and reading it in the contraction, nor between
// component iterations. Each thread reads values it also wrote itself, but
// other threads of the same tidz row may already be writing the next
// component's values. This is only safe if a tidz row always executes in
// lock-step (e.g. fits within one wavefront) — TODO confirm launch config.
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *s_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;  // passed as the (unused) U argument of the contractions

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;

  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for (int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) {
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
        writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      } else {
        readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}

//------------------------------------------------------------------------------
// 1D derivatives at quadrature points
//------------------------------------------------------------------------------
// Applies the 1D derivative matrix s_G (transpose == 0: nodes -> qpts, written
// as dim 0) or its transpose (transpose != 0: dim-0 qpt data -> nodes) to every
// component of every element. s_B is accepted but not used in 1D. Elements are
// distributed one per threadIdx.z, striding by gridDim.x*blockDim.z.
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *s_B, const CeedScalar *s_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  CeedScalar r_U;  // passed as the (unused) U argument of the contractions
  CeedScalar r_V;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  int dim;

  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) {
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_U, s_G, r_V);
        dim = 0;
        writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      } else {
        dim = 0;
        readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, s_G, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}

//------------------------------------------------------------------------------
// 1D Quadrature weights
//------------------------------------------------------------------------------
// Writes qweight1d[threadIdx.x] to w[elem*Q1D + threadIdx.x] for each element.
// Here elements are batched along threadIdx.y (stride gridDim.x*blockDim.y).
// NOTE(review): qweight1d is read without a bound check; this assumes
// blockDim.x <= Q1D at launch — confirm against the host launch code.
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int tid = threadIdx.x;
  const CeedScalar weight = qweight1d[tid];
  for (CeedInt elem = blockIdx.x*blockDim.y + threadIdx.y; elem < nelem;
       elem += gridDim.x*blockDim.y) {
    const int ind = elem*Q1D + tid;
    w[ind] = weight;
  }
}

//------------------------------------------------------------------------------
// 2D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// Read DoFs
//------------------------------------------------------------------------------
// Returns in U the nodal value at (tidx, tidy) of component `comp` of element
// `elem` (d_U indexed as [comp][elem][y][x]); threads outside P1D x P1D get 0.
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar &U) {
  U = (tidx<P1D && tidy<P1D) ?
      d_U[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] : 0.0;
}

//------------------------------------------------------------------------------
// Write DoFs
//------------------------------------------------------------------------------
// Stores r_V as node (tidx, tidy) of (comp, elem); only threads inside
// P1D x P1D write.
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx<P1D && tidy<P1D)
    d_V[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] = r_V;
}

//------------------------------------------------------------------------------
// Read quadrature point data
//------------------------------------------------------------------------------
// Returns in U the value at quadrature point (tidx, tidy) of (dim, comp, elem)
// (d_U indexed as [dim][comp][elem][y][x]); threads outside Q1D x Q1D get 0.
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar &U ) {
  U = (tidx<Q1D && tidy<Q1D) ?
      d_U[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
      dim*BASIS_NCOMP*nelem*Q1D*Q1D] : 0.0;
}

//------------------------------------------------------------------------------
// Write quadrature point data
//------------------------------------------------------------------------------
// Stores r_V at quadrature point (tidx, tidy) of (dim, comp, elem); only
// threads inside Q1D x Q1D write.
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx<Q1D && tidy<Q1D)
    d_V[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
        dim*BASIS_NCOMP*nelem*Q1D*Q1D] = r_V;
}

//------------------------------------------------------------------------------
// 2D tensor contraction x
//------------------------------------------------------------------------------
// Each thread stages U in its own cell of a T1D x T1D tile of slice (one tile
// per tidz). After a barrier, threads with tidx < Q1D compute
// V = sum_i B[i + tidx*P1D] * tile[tidy][i]; other threads return V = 0.
// The second barrier keeps the tile intact until every thread has read it.
// Both __syncthreads() calls must be reached by all threads of the block.
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidx < Q1D)
    for (int i = 0; i < P1D; ++i)
      V += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
  __syncthreads();
}

//------------------------------------------------------------------------------
// 2D tensor contraction y
//------------------------------------------------------------------------------
// Stages U in this thread's cell of the tidz tile of slice, then threads with
// tidy < Q1D compute V = sum_i B[i + tidy*P1D] * tile[i][tidx]; others get 0.
// Both barriers must be reached by all threads of the block.
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidy < Q1D)
    for (int i = 0; i < P1D; ++i)
      V += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
  __syncthreads();
}

//------------------------------------------------------------------------------
// 2D transpose tensor contraction y
//------------------------------------------------------------------------------
// Transpose of ContractY2d: threads with tidy < P1D compute
// V = sum_i B[tidy + i*P1D] * tile[i][tidx] over i < Q1D; others get 0.
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidy < P1D)
    for (int i = 0; i < Q1D; ++i)
      V += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
  __syncthreads();
}

//------------------------------------------------------------------------------
// 2D transpose tensor contraction x
//------------------------------------------------------------------------------
// Transpose of ContractX2d: threads with tidx < P1D compute
// V = sum_i B[tidx + i*P1D] * tile[tidy][i] over i < Q1D; others get 0.
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidx < P1D)
    for (int i = 0; i < Q1D; ++i)
      V += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
  __syncthreads();
}

//------------------------------------------------------------------------------
// 2D interpolate to quadrature points
//------------------------------------------------------------------------------
// Applies s_B along x then y (transpose == 0: nodes -> qpts) or the transposes
// along y then x (transpose != 0: qpts -> nodes). threadIdx.z enumerates
// (element, component) pairs: blockElem = tidz / BASIS_NCOMP and
// comp = tidz % BASIS_NCOMP, so blockDim.z / BASIS_NCOMP elements are handled
// per block, striding by gridDim.x * elemsPerBlock.
// NOTE(review): the loop trip count depends on blockElem, while the
// contractions contain __syncthreads(). This relies on all threads of a block
// running the same number of iterations — confirm the launch config / nelem
// padding guarantees that.
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *s_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // `comp` is loop-invariant (depends only on tidz); the former shadowing
    // redeclaration inside this loop was redundant and has been removed.
    r_V = 0.0;
    r_t = 0.0;
    if (!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX2d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
    } else {
      readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}

//------------------------------------------------------------------------------
// 2D derivatives at quadrature points
//------------------------------------------------------------------------------
// transpose == 0: writes dim 0 = (G along x, B along y) and
// dim 1 = (B along x, G along y) of each component at the quadrature points.
// transpose != 0: reads dims 0 and 1, applies the matching transposed
// contractions and sums both contributions into the nodal result.
// Thread/element layout is the same as interp2d (same barrier caveat applies).
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *s_B, const CeedScalar *s_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V, CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if (!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      ContractX2d(slice, tidx, tidy, tidz, r_U, s_G, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
      dim = 0;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      ContractX2d(slice, tidx, tidy, tidz, r_U, s_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, s_G, r_V);
      dim = 1;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      dim = 0;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, s_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, s_G, r_V);
      dim = 1;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, s_G, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, s_B, r_U);
      r_V += r_U;  // sum the x- and y-derivative contributions
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}

//------------------------------------------------------------------------------
// 2D quadrature weights
//------------------------------------------------------------------------------
// Writes the tensor-product weight qweight1d[x]*qweight1d[y] for thread (x, y)
// to w[elem*Q1D*Q1D + x + y*Q1D]; elements are batched along threadIdx.z.
// NOTE(review): qweight1d is read without a bound check; assumes
// blockDim.x, blockDim.y <= Q1D at launch — confirm.
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int i = threadIdx.x;
  const int j = threadIdx.y;
  const CeedScalar weight = qweight1d[i]*qweight1d[j];
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    const int ind = elem*Q1D*Q1D + i + j*Q1D;
    w[ind] = weight;
  }
}

//------------------------------------------------------------------------------
// 3D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// Read DoFs
//------------------------------------------------------------------------------
// Fills r_U[i] with the nodal value at (tidx, tidy, z = i) of (comp, elem) for
// i < P1D (d_U indexed as [comp][elem][z][y][x]); threads outside P1D x P1D
// get 0. Entries P1D..Q1D-1 are zeroed.
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *r_U) {
  for (int i = 0; i < P1D; i++)
    r_U[i] = (tidx < P1D && tidy < P1D) ?
             d_U[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D +
                 comp*P1D*P1D*P1D*nelem] : 0.0;
  for (int i = P1D; i < Q1D; i++)
    r_U[i] = 0.0;
}

//------------------------------------------------------------------------------
// Write DoFs
//------------------------------------------------------------------------------
// Stores r_V[0..P1D) as the z-column of nodes at (tidx, tidy) of (comp, elem);
// only threads inside P1D x P1D write.
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V,
                                   CeedScalar *d_V) {
  if (tidx < P1D && tidy < P1D) {
    for (int i = 0; i < P1D; i++)
      d_V[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D +
          comp*P1D*P1D*P1D*nelem] = r_V[i];
  }
}

//------------------------------------------------------------------------------
// Read quadrature point data
//------------------------------------------------------------------------------
// Fills r_U[i] with the value at quadrature point (tidx, tidy, z = i) of
// (dim, comp, elem) for i < Q1D (d_U indexed as [dim][comp][elem][z][y][x]);
// threads outside Q1D x Q1D get 0. Entries Q1D..P1D-1 are zeroed.
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *r_U) {
  for (int i = 0; i < Q1D; i++)
    r_U[i] = (tidx < Q1D && tidy < Q1D) ?
             d_U[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D +
             comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] : 0.0;
  for (int i = Q1D; i < P1D; i++)
    r_U[i] = 0.0;
}

//------------------------------------------------------------------------------
// Write quadrature point data
//------------------------------------------------------------------------------
// Stores r_V[0..Q1D) as the z-column of quadrature values at (tidx, tidy) of
// (dim, comp, elem); only threads inside Q1D x Q1D write.
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar *r_V, CeedScalar *d_V) {
  if (tidx < Q1D && tidy < Q1D) {
    for (int i = 0; i < Q1D; i++)
      d_V[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
          dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] = r_V[i];
  }
}

//------------------------------------------------------------------------------
// 3D tensor contract x
//------------------------------------------------------------------------------
// For each z-layer k < P1D: stage U[k] in this thread's cell of the tidz tile
// of slice, barrier, then threads with tidx < Q1D && tidy < P1D compute
// V[k] = sum_i B[i + tidx*P1D] * tile[tidy][i]; others get V[k] = 0.
// The trailing barrier protects the tile before the next layer overwrites it.
// All barriers must be reached by every thread of the block.
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < Q1D && tidy < P1D)
      for (int i = 0; i < P1D; ++i)
        V[k] += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
    __syncthreads();
  }
}

//------------------------------------------------------------------------------
// 3D tensor contract y
//------------------------------------------------------------------------------
// Same staging pattern as ContractX3d; threads with tidx < Q1D && tidy < Q1D
// compute V[k] = sum_i B[i + tidy*P1D] * tile[i][tidx] for each k < P1D.
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < Q1D && tidy < Q1D)
      for (int i = 0; i < P1D; ++i)
        V[k] += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
    __syncthreads();
  }
}

//------------------------------------------------------------------------------
// 3D tensor contract z
//------------------------------------------------------------------------------
// Contracts along z using only this thread's own U array (no slice access, no
// barrier): V[k] = sum_{i<P1D} B[i + k*P1D] * U[i] for k < Q1D on threads with
// tidx, tidy < Q1D. V[Q1D..P1D) is zeroed. slice and tidz are unused.
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < Q1D; ++k) {
    V[k] = 0.0;
    if (tidx < Q1D && tidy < Q1D)
      for (int i = 0; i < P1D; ++i)
        V[k] += B[i + k*P1D] * U[i]; // Contract z direction
  }
  for (int k = Q1D; k < P1D; ++k)
    V[k] = 0.0;
}

//------------------------------------------------------------------------------
// 3D transpose tensor contract z
//------------------------------------------------------------------------------
// Transpose of ContractZ3d, thread-local: V[k] = sum_{i<Q1D} B[k + i*P1D] * U[i]
// for k < P1D on threads with tidx, tidy < Q1D. V[P1D..Q1D) is zeroed.
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    V[k] = 0.0;
    if (tidx < Q1D && tidy < Q1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[k + i*P1D] * U[i]; // Contract z direction
  }
  for (int k = P1D; k < Q1D; ++k)
    V[k] = 0.0;
}

//------------------------------------------------------------------------------
// 3D transpose tensor contract y
//------------------------------------------------------------------------------
// Transpose of ContractY3d: per layer k < P1D, threads with tidx < Q1D &&
// tidy < P1D compute V[k] = sum_{i<Q1D} B[tidy + i*P1D] * tile[i][tidx].
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < Q1D && tidy < P1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
    __syncthreads();
  }
}

//------------------------------------------------------------------------------
// 3D transpose tensor contract x
//------------------------------------------------------------------------------
// Transpose of ContractX3d: per layer k < P1D, threads with tidx < P1D &&
// tidy < P1D compute V[k] = sum_{i<Q1D} B[tidx + i*P1D] * tile[tidy][i].
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < P1D && tidy < P1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
    __syncthreads();
  }
}

//------------------------------------------------------------------------------
// 3D interpolate to quadrature points
//------------------------------------------------------------------------------
// transpose == 0: applies s_B along x, y, then z (nodes -> qpts).
// transpose != 0: applies the transposes along z, y, then x (qpts -> nodes).
// Each thread holds one (x, y) column of values along z in r_V / r_t (T1D
// entries). threadIdx.z enumerates (element, component) pairs as in interp2d.
// NOTE(review): as in 2D, the element loop's trip count depends on blockElem
// while the x/y contractions contain __syncthreads(); confirm all threads of a
// block iterate the same number of times.
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *s_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V[T1D];
  CeedScalar r_t[T1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    for (int i = 0; i < T1D; ++i) {
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX3d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
      ContractY3d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
      ContractZ3d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
      writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
    } else {
      readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
    }
  }
}

//------------------------------------------------------------------------------
// 3D derivatives at quadrature points
//------------------------------------------------------------------------------
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *s_B, const CeedScalar *s_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // Use P1D for one of these
  CeedScalar r_U[T1D];
6377d8d0e25Snbeams CeedScalar r_V[T1D]; 6387d8d0e25Snbeams CeedScalar r_t[T1D]; 6397d8d0e25Snbeams 6407d8d0e25Snbeams const int tidx = threadIdx.x; 6417d8d0e25Snbeams const int tidy = threadIdx.y; 6427d8d0e25Snbeams const int tidz = threadIdx.z; 6437d8d0e25Snbeams const int blockElem = tidz/BASIS_NCOMP; 6447d8d0e25Snbeams const int elemsPerBlock = blockDim.z/BASIS_NCOMP; 6457d8d0e25Snbeams const int comp = tidz%BASIS_NCOMP; 6467d8d0e25Snbeams int dim; 6477d8d0e25Snbeams 6487d8d0e25Snbeams for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem; 6497d8d0e25Snbeams elem += gridDim.x*elemsPerBlock) { 6507d8d0e25Snbeams for (int i = 0; i < T1D; ++i) { 6517d8d0e25Snbeams r_U[i] = 0.0; 6527d8d0e25Snbeams r_V[i] = 0.0; 6537d8d0e25Snbeams r_t[i] = 0.0; 6547d8d0e25Snbeams } 6557d8d0e25Snbeams if (!transpose) { 6567d8d0e25Snbeams readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U); 657e9132427Snbeams ContractX3d(slice, tidx, tidy, tidz, r_U, s_G, r_V); 658e9132427Snbeams ContractY3d(slice, tidx, tidy, tidz, r_V, s_B, r_t); 659e9132427Snbeams ContractZ3d(slice, tidx, tidy, tidz, r_t, s_B, r_V); 6607d8d0e25Snbeams dim = 0; 6617d8d0e25Snbeams writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V); 662e9132427Snbeams ContractX3d(slice, tidx, tidy, tidz, r_U, s_B, r_V); 663e9132427Snbeams ContractY3d(slice, tidx, tidy, tidz, r_V, s_G, r_t); 664e9132427Snbeams ContractZ3d(slice, tidx, tidy, tidz, r_t, s_B, r_V); 6657d8d0e25Snbeams dim = 1; 6667d8d0e25Snbeams writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V); 667e9132427Snbeams ContractX3d(slice, tidx, tidy, tidz, r_U, s_B, r_V); 668e9132427Snbeams ContractY3d(slice, tidx, tidy, tidz, r_V, s_B, r_t); 669e9132427Snbeams ContractZ3d(slice, tidx, tidy, tidz, r_t, s_G, r_V); 6707d8d0e25Snbeams dim = 2; 6717d8d0e25Snbeams writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V); 6727d8d0e25Snbeams } else { 6737d8d0e25Snbeams dim = 0; 6747d8d0e25Snbeams readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, 
r_U); 675e9132427Snbeams ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, s_B, r_t); 676e9132427Snbeams ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, s_B, r_U); 677e9132427Snbeams ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, s_G, r_V); 6787d8d0e25Snbeams dim = 1; 6797d8d0e25Snbeams readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U); 680e9132427Snbeams ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, s_B, r_t); 681e9132427Snbeams ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, s_G, r_U); 682e9132427Snbeams ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, s_B, r_t); 6837d8d0e25Snbeams add(r_V, r_t); 6847d8d0e25Snbeams dim = 2; 6857d8d0e25Snbeams readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U); 686e9132427Snbeams ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, s_G, r_t); 687e9132427Snbeams ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, s_B, r_U); 688e9132427Snbeams ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, s_B, r_t); 6897d8d0e25Snbeams add(r_V, r_t); 6907d8d0e25Snbeams writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V); 6917d8d0e25Snbeams } 6927d8d0e25Snbeams } 6937d8d0e25Snbeams } 6947d8d0e25Snbeams 6957d8d0e25Snbeams //------------------------------------------------------------------------------ 6967d8d0e25Snbeams // 3D quadrature weights 6977d8d0e25Snbeams //------------------------------------------------------------------------------ 6987d8d0e25Snbeams __device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d, 6997d8d0e25Snbeams CeedScalar *w) { 7007d8d0e25Snbeams const int i = threadIdx.x; 7017d8d0e25Snbeams const int j = threadIdx.y; 7027d8d0e25Snbeams const int k = threadIdx.z; 7037d8d0e25Snbeams const CeedScalar weight = qweight1d[i]*qweight1d[j]*qweight1d[k]; 7047d8d0e25Snbeams for (int e = blockIdx.x; e < nelem; e += gridDim.x) { 7057d8d0e25Snbeams const int ind = e*Q1D*Q1D*Q1D + i + j*Q1D + k*Q1D*Q1D; 7067d8d0e25Snbeams w[ind] = weight; 7077d8d0e25Snbeams } 7087d8d0e25Snbeams } 
7097d8d0e25Snbeams 7107d8d0e25Snbeams //------------------------------------------------------------------------------ 7117d8d0e25Snbeams // Basis kernels 7127d8d0e25Snbeams //------------------------------------------------------------------------------ 7137d8d0e25Snbeams 7147d8d0e25Snbeams //------------------------------------------------------------------------------ 7157d8d0e25Snbeams // Interp kernel by dim 7167d8d0e25Snbeams //------------------------------------------------------------------------------ 7179e31c45bSnbeams extern "C" __launch_bounds__(INTERP_BLKSIZE) __global__ void interp( 7189e31c45bSnbeams const CeedInt nelem, const int transpose, 7199dd88646Snbeams CeedScalar *d_interp1d, 7207d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 7217d8d0e25Snbeams CeedScalar *__restrict__ d_V) { 7229dd88646Snbeams 7237d8d0e25Snbeams HIP_DYNAMIC_SHARED( double, slice) 7249dd88646Snbeams // load interp1d into shared memory 725e9132427Snbeams __shared__ double s_B[P1D*Q1D]; 726e9132427Snbeams loadMatrix(d_interp1d, s_B); 72755250509Snbeams __syncthreads(); 7289dd88646Snbeams 7297d8d0e25Snbeams if (BASIS_DIM == 1) { 730e9132427Snbeams interp1d(nelem, transpose, s_B, d_U, d_V, slice); 7317d8d0e25Snbeams } else if (BASIS_DIM == 2) { 732e9132427Snbeams interp2d(nelem, transpose, s_B, d_U, d_V, slice); 7337d8d0e25Snbeams } else if (BASIS_DIM == 3) { 734e9132427Snbeams interp3d(nelem, transpose, s_B, d_U, d_V, slice); 7357d8d0e25Snbeams } 7367d8d0e25Snbeams } 7377d8d0e25Snbeams 7387d8d0e25Snbeams //------------------------------------------------------------------------------ 7397d8d0e25Snbeams // Grad kernel by dim 7407d8d0e25Snbeams //------------------------------------------------------------------------------ 7419e31c45bSnbeams extern "C" __launch_bounds__(GRAD_BLKSIZE) __global__ void grad(const CeedInt nelem, 7429e31c45bSnbeams const int transpose, 7439dd88646Snbeams CeedScalar *d_interp1d, CeedScalar *d_grad1d, 7447d8d0e25Snbeams const CeedScalar 
*__restrict__ d_U, 7457d8d0e25Snbeams CeedScalar *__restrict__ d_V) { 7467d8d0e25Snbeams HIP_DYNAMIC_SHARED( double, slice) 7479dd88646Snbeams // load interp1d and grad1d into shared memory 748e9132427Snbeams __shared__ double s_B[P1D*Q1D]; 749e9132427Snbeams loadMatrix(d_interp1d, s_B); 750e9132427Snbeams __shared__ double s_G[P1D*Q1D]; 751e9132427Snbeams loadMatrix(d_grad1d, s_G); 75255250509Snbeams __syncthreads(); 7539dd88646Snbeams 7547d8d0e25Snbeams if (BASIS_DIM == 1) { 755e9132427Snbeams grad1d(nelem, transpose, s_B, s_G, d_U, d_V, slice); 7567d8d0e25Snbeams } else if (BASIS_DIM == 2) { 757e9132427Snbeams grad2d(nelem, transpose, s_B, s_G, d_U, d_V, slice); 7587d8d0e25Snbeams } else if (BASIS_DIM == 3) { 759e9132427Snbeams grad3d(nelem, transpose, s_B, s_G, d_U, d_V, slice); 7607d8d0e25Snbeams } 7617d8d0e25Snbeams } 7627d8d0e25Snbeams 7637d8d0e25Snbeams //------------------------------------------------------------------------------ 7647d8d0e25Snbeams // Weight kernels by dim 7657d8d0e25Snbeams //------------------------------------------------------------------------------ 7669e31c45bSnbeams extern "C" __launch_bounds__(WEIGHT_BLKSIZE) __global__ void weight(const CeedInt nelem, 7677d8d0e25Snbeams const CeedScalar *__restrict__ qweight1d, 7687d8d0e25Snbeams CeedScalar *__restrict__ v) { 7697d8d0e25Snbeams if (BASIS_DIM == 1) { 7707d8d0e25Snbeams weight1d(nelem, qweight1d, v); 7717d8d0e25Snbeams } else if (BASIS_DIM == 2) { 7727d8d0e25Snbeams weight2d(nelem, qweight1d, v); 7737d8d0e25Snbeams } else if (BASIS_DIM == 3) { 7747d8d0e25Snbeams weight3d(nelem, qweight1d, v); 7757d8d0e25Snbeams } 7767d8d0e25Snbeams } 7777d8d0e25Snbeams 7787d8d0e25Snbeams ); 7797d8d0e25Snbeams // *INDENT-ON* 7807d8d0e25Snbeams 7817d8d0e25Snbeams //------------------------------------------------------------------------------ 7829e31c45bSnbeams // Compute a block size based on required minimum threads 7839e31c45bSnbeams 
//------------------------------------------------------------------------------ 7849e31c45bSnbeams static CeedInt ComputeBlockSizeFromRequirement(const CeedInt required) { 7859e31c45bSnbeams CeedInt maxSize = 1024; // Max total threads per block 7869e31c45bSnbeams CeedInt currentSize = 64; // Start with one group 7879e31c45bSnbeams 7889e31c45bSnbeams while(currentSize < maxSize) { 7899e31c45bSnbeams if (currentSize > required) 7909e31c45bSnbeams break; 7919e31c45bSnbeams else 7929e31c45bSnbeams currentSize = currentSize * 2; 7939e31c45bSnbeams } 7949e31c45bSnbeams return currentSize; 7959e31c45bSnbeams } 7969e31c45bSnbeams 7979e31c45bSnbeams //------------------------------------------------------------------------------ 7989e31c45bSnbeams // Compute required thread block sizes for basis kernels given P, Q, dim, and 7999e31c45bSnbeams // ncomp 8009e31c45bSnbeams //------------------------------------------------------------------------------ 8019e31c45bSnbeams static int ComputeBasisThreadBlockSizes(const CeedInt dim, const CeedInt P1d, 8029e31c45bSnbeams const CeedInt Q1d, 8039e31c45bSnbeams const CeedInt ncomp, CeedInt *blksizes) { 8049e31c45bSnbeams 8059e31c45bSnbeams // Note that this will use the same block sizes for all dimensions when compiling, 8069e31c45bSnbeams // but as each basis object is defined for a particular dimension, we will never 8079e31c45bSnbeams // call any kernels except the ones for the dimension for which we have computed the 8089e31c45bSnbeams // block sizes. 
8099e31c45bSnbeams const CeedInt thread1d = CeedIntMax(P1d, Q1d); 8109e31c45bSnbeams switch (dim) { 8119e31c45bSnbeams case 1: { 8129e31c45bSnbeams // Interp kernels: 8139e31c45bSnbeams blksizes[0] = 256; 8149e31c45bSnbeams 8159e31c45bSnbeams // Grad kernels: 8169e31c45bSnbeams blksizes[1] = 256; 8179e31c45bSnbeams 8189e31c45bSnbeams // Weight kernels: 8199e31c45bSnbeams blksizes[2] = 256; 8209e31c45bSnbeams 8219e31c45bSnbeams } break; 8229e31c45bSnbeams case 2: { 8239e31c45bSnbeams // Interp kernels: 8249e31c45bSnbeams CeedInt required = thread1d * thread1d * ncomp; 8259e31c45bSnbeams blksizes[0] = ComputeBlockSizeFromRequirement(required); 8269e31c45bSnbeams 8279e31c45bSnbeams // Grad kernels: currently use same required minimum threads 8289e31c45bSnbeams blksizes[1] = ComputeBlockSizeFromRequirement(required); 8299e31c45bSnbeams 8309e31c45bSnbeams // Weight kernels: 8319e31c45bSnbeams required = CeedIntMax(64, Q1d * Q1d); 8329e31c45bSnbeams blksizes[2] = ComputeBlockSizeFromRequirement(required); 8339e31c45bSnbeams 8349e31c45bSnbeams } break; 8359e31c45bSnbeams case 3: { 8369e31c45bSnbeams // Interp kernels: 8379e31c45bSnbeams CeedInt required = thread1d * thread1d * ncomp; 8389e31c45bSnbeams blksizes[0] = ComputeBlockSizeFromRequirement(required); 8399e31c45bSnbeams 8409e31c45bSnbeams // Grad kernels: currently use same required minimum threads 8419e31c45bSnbeams blksizes[1] = ComputeBlockSizeFromRequirement(required); 8429e31c45bSnbeams 8439e31c45bSnbeams // Weight kernels: 8449e31c45bSnbeams required = Q1d * Q1d * Q1d; 8459e31c45bSnbeams blksizes[2] = ComputeBlockSizeFromRequirement(required); 8469e31c45bSnbeams } 8479e31c45bSnbeams } 8489e31c45bSnbeams 849e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 8509e31c45bSnbeams } 8519e31c45bSnbeams 8529e31c45bSnbeams //------------------------------------------------------------------------------ 8537d8d0e25Snbeams // Apply basis 8547d8d0e25Snbeams 
//------------------------------------------------------------------------------ 8557d8d0e25Snbeams int CeedBasisApplyTensor_Hip_shared(CeedBasis basis, const CeedInt nelem, 8567d8d0e25Snbeams CeedTransposeMode tmode, 8577d8d0e25Snbeams CeedEvalMode emode, CeedVector u, 8587d8d0e25Snbeams CeedVector v) { 8597d8d0e25Snbeams int ierr; 8607d8d0e25Snbeams Ceed ceed; 861e15f9bd0SJeremy L Thompson ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr); 862*6dbfb411Snbeams Ceed_Hip *ceed_Hip; 863e15f9bd0SJeremy L Thompson CeedGetData(ceed, &ceed_Hip); CeedChkBackend(ierr); 8647d8d0e25Snbeams CeedBasis_Hip_shared *data; 865e15f9bd0SJeremy L Thompson CeedBasisGetData(basis, &data); CeedChkBackend(ierr); 8667d8d0e25Snbeams const CeedInt transpose = tmode == CEED_TRANSPOSE; 8677d8d0e25Snbeams CeedInt dim, ncomp; 868e15f9bd0SJeremy L Thompson ierr = CeedBasisGetDimension(basis, &dim); CeedChkBackend(ierr); 869e15f9bd0SJeremy L Thompson ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChkBackend(ierr); 8707d8d0e25Snbeams 8717d8d0e25Snbeams // Read vectors 8727d8d0e25Snbeams const CeedScalar *d_u; 8737d8d0e25Snbeams CeedScalar *d_v; 8747d8d0e25Snbeams if (emode != CEED_EVAL_WEIGHT) { 875e15f9bd0SJeremy L Thompson ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChkBackend(ierr); 8767d8d0e25Snbeams } 877e15f9bd0SJeremy L Thompson ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChkBackend(ierr); 8787d8d0e25Snbeams 8797d8d0e25Snbeams // Clear v for transpose mode 8807d8d0e25Snbeams if (tmode == CEED_TRANSPOSE) { 8817d8d0e25Snbeams CeedInt length; 882e15f9bd0SJeremy L Thompson ierr = CeedVectorGetLength(v, &length); CeedChkBackend(ierr); 883e15f9bd0SJeremy L Thompson ierr = hipMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChkBackend(ierr); 8847d8d0e25Snbeams } 8857d8d0e25Snbeams 8867d8d0e25Snbeams // Apply basis operation 8877d8d0e25Snbeams switch (emode) { 8887d8d0e25Snbeams case CEED_EVAL_INTERP: { 8897d8d0e25Snbeams CeedInt P1d, Q1d; 
8909e31c45bSnbeams CeedInt blksize = data->blksizes[0]; 891e15f9bd0SJeremy L Thompson ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChkBackend(ierr); 892e15f9bd0SJeremy L Thompson ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr); 8937d8d0e25Snbeams CeedInt thread1d = CeedIntMax(Q1d, P1d); 8949dd88646Snbeams void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->d_interp1d, 8957d8d0e25Snbeams &d_u, &d_v 8967d8d0e25Snbeams }; 8977d8d0e25Snbeams if (dim == 1) { 898e7ea6884Snbeams CeedInt elemsPerBlock = 64*thread1d > 256? 256/thread1d : 64; 8997d8d0e25Snbeams elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1; 9007d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 9017d8d0e25Snbeams ? 1 : 0 ); 9027d8d0e25Snbeams CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar); 9037d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, 1, 9047d8d0e25Snbeams elemsPerBlock, sharedMem, 905e15f9bd0SJeremy L Thompson interpargs); CeedChkBackend(ierr); 9067d8d0e25Snbeams } else if (dim == 2) { 9079e31c45bSnbeams // Check if required threads is small enough to do multiple elems 9089e31c45bSnbeams const CeedInt elemsPerBlock = CeedIntMax(blksize/(thread1d*thread1d*ncomp), 1); 9097d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 9107d8d0e25Snbeams ? 1 : 0 ); 9117d8d0e25Snbeams CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar); 9127d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d, 9137d8d0e25Snbeams ncomp*elemsPerBlock, sharedMem, 914e15f9bd0SJeremy L Thompson interpargs); CeedChkBackend(ierr); 9157d8d0e25Snbeams } else if (dim == 3) { 9167d8d0e25Snbeams CeedInt elemsPerBlock = 1; 9177d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 9187d8d0e25Snbeams ? 
1 : 0 ); 9197d8d0e25Snbeams CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar); 9207d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d, 9217d8d0e25Snbeams ncomp*elemsPerBlock, sharedMem, 922e15f9bd0SJeremy L Thompson interpargs); CeedChkBackend(ierr); 9237d8d0e25Snbeams } 9247d8d0e25Snbeams } break; 9257d8d0e25Snbeams case CEED_EVAL_GRAD: { 9267d8d0e25Snbeams CeedInt P1d, Q1d; 9279e31c45bSnbeams CeedInt blksize = data->blksizes[1]; 928e15f9bd0SJeremy L Thompson ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChkBackend(ierr); 929e15f9bd0SJeremy L Thompson ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr); 9307d8d0e25Snbeams CeedInt thread1d = CeedIntMax(Q1d, P1d); 9319dd88646Snbeams void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->d_interp1d, 9329dd88646Snbeams &data->d_grad1d, &d_u, &d_v 9337d8d0e25Snbeams }; 9347d8d0e25Snbeams if (dim == 1) { 935e7ea6884Snbeams CeedInt elemsPerBlock = 64*thread1d > 256? 256/thread1d : 64; 9367d8d0e25Snbeams elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1; 9377d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 9387d8d0e25Snbeams ? 1 : 0 ); 9397d8d0e25Snbeams CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar); 9407d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, 1, 9417d8d0e25Snbeams elemsPerBlock, sharedMem, gradargs); 942e15f9bd0SJeremy L Thompson CeedChkBackend(ierr); 9437d8d0e25Snbeams } else if (dim == 2) { 9449e31c45bSnbeams // Check if required threads is small enough to do multiple elems 9459e31c45bSnbeams const CeedInt elemsPerBlock = CeedIntMax(blksize/(thread1d*thread1d*ncomp), 1); 9467d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 9477d8d0e25Snbeams ? 
1 : 0 ); 9487d8d0e25Snbeams CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar); 9497d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d, 9507d8d0e25Snbeams ncomp*elemsPerBlock, sharedMem, 951e15f9bd0SJeremy L Thompson gradargs); CeedChkBackend(ierr); 9527d8d0e25Snbeams } else if (dim == 3) { 9537d8d0e25Snbeams CeedInt elemsPerBlock = 1; 9547d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 9557d8d0e25Snbeams ? 1 : 0 ); 9567d8d0e25Snbeams CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar); 9577d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d, 9587d8d0e25Snbeams ncomp*elemsPerBlock, sharedMem, 959e15f9bd0SJeremy L Thompson gradargs); CeedChkBackend(ierr); 9607d8d0e25Snbeams } 9617d8d0e25Snbeams } break; 9627d8d0e25Snbeams case CEED_EVAL_WEIGHT: { 9637d8d0e25Snbeams CeedInt Q1d; 9649e31c45bSnbeams CeedInt blksize = data->blksizes[2]; 965e15f9bd0SJeremy L Thompson ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr); 9667d8d0e25Snbeams void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v}; 9677d8d0e25Snbeams if (dim == 1) { 9689e31c45bSnbeams const CeedInt optElems = blksize/Q1d; 9697d8d0e25Snbeams const CeedInt elemsPerBlock = optElems>0?optElems:1; 9707d8d0e25Snbeams const CeedInt gridsize = nelem/elemsPerBlock + ( ( 9717d8d0e25Snbeams nelem/elemsPerBlock*elemsPerBlock<nelem)? 
1 : 0 ); 9727d8d0e25Snbeams ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, 9737d8d0e25Snbeams elemsPerBlock, 1, weightargs); 974e15f9bd0SJeremy L Thompson CeedChkBackend(ierr); 9757d8d0e25Snbeams } else if (dim == 2) { 9769e31c45bSnbeams const CeedInt optElems = blksize/(Q1d*Q1d); 9777d8d0e25Snbeams const CeedInt elemsPerBlock = optElems>0?optElems:1; 9787d8d0e25Snbeams const CeedInt gridsize = nelem/elemsPerBlock + ( ( 9797d8d0e25Snbeams nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 ); 9807d8d0e25Snbeams ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d, 9817d8d0e25Snbeams elemsPerBlock, weightargs); 982e15f9bd0SJeremy L Thompson CeedChkBackend(ierr); 9837d8d0e25Snbeams } else if (dim == 3) { 9847d8d0e25Snbeams const CeedInt gridsize = nelem; 9857d8d0e25Snbeams ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d, Q1d, 9867d8d0e25Snbeams weightargs); 987e15f9bd0SJeremy L Thompson CeedChkBackend(ierr); 9887d8d0e25Snbeams } 9897d8d0e25Snbeams } break; 9907d8d0e25Snbeams // LCOV_EXCL_START 9917d8d0e25Snbeams // Evaluate the divergence to/from the quadrature points 9927d8d0e25Snbeams case CEED_EVAL_DIV: 993e15f9bd0SJeremy L Thompson return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported"); 9947d8d0e25Snbeams // Evaluate the curl to/from the quadrature points 9957d8d0e25Snbeams case CEED_EVAL_CURL: 996e15f9bd0SJeremy L Thompson return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported"); 9977d8d0e25Snbeams // Take no action, BasisApply should not have been called 9987d8d0e25Snbeams case CEED_EVAL_NONE: 999e15f9bd0SJeremy L Thompson return CeedError(ceed, CEED_ERROR_BACKEND, 10007d8d0e25Snbeams "CEED_EVAL_NONE does not make sense in this context"); 10017d8d0e25Snbeams // LCOV_EXCL_STOP 10027d8d0e25Snbeams } 10037d8d0e25Snbeams 10047d8d0e25Snbeams // Restore vectors 10057d8d0e25Snbeams if (emode != CEED_EVAL_WEIGHT) { 1006e15f9bd0SJeremy L Thompson ierr = CeedVectorRestoreArrayRead(u, &d_u); 
CeedChkBackend(ierr); 10077d8d0e25Snbeams } 1008e15f9bd0SJeremy L Thompson ierr = CeedVectorRestoreArray(v, &d_v); CeedChkBackend(ierr); 1009e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 10107d8d0e25Snbeams } 10117d8d0e25Snbeams 10127d8d0e25Snbeams //------------------------------------------------------------------------------ 10137d8d0e25Snbeams // Destroy basis 10147d8d0e25Snbeams //------------------------------------------------------------------------------ 10157d8d0e25Snbeams static int CeedBasisDestroy_Hip_shared(CeedBasis basis) { 10167d8d0e25Snbeams int ierr; 10177d8d0e25Snbeams Ceed ceed; 1018e15f9bd0SJeremy L Thompson ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr); 10197d8d0e25Snbeams 10207d8d0e25Snbeams CeedBasis_Hip_shared *data; 1021e15f9bd0SJeremy L Thompson ierr = CeedBasisGetData(basis, &data); CeedChkBackend(ierr); 10227d8d0e25Snbeams 10237d8d0e25Snbeams CeedChk_Hip(ceed, hipModuleUnload(data->module)); 10247d8d0e25Snbeams 10257d8d0e25Snbeams ierr = hipFree(data->d_qweight1d); CeedChk_Hip(ceed, ierr); 10267d8d0e25Snbeams ierr = hipFree(data->d_interp1d); CeedChk_Hip(ceed, ierr); 10277d8d0e25Snbeams ierr = hipFree(data->d_grad1d); CeedChk_Hip(ceed, ierr); 10287d8d0e25Snbeams ierr = hipFree(data->d_collograd1d); CeedChk_Hip(ceed, ierr); 10297d8d0e25Snbeams 1030e15f9bd0SJeremy L Thompson ierr = CeedFree(&data); CeedChkBackend(ierr); 10317d8d0e25Snbeams 1032e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 10337d8d0e25Snbeams } 10347d8d0e25Snbeams 10357d8d0e25Snbeams //------------------------------------------------------------------------------ 10367d8d0e25Snbeams // Create tensor basis 10377d8d0e25Snbeams //------------------------------------------------------------------------------ 10387d8d0e25Snbeams int CeedBasisCreateTensorH1_Hip_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d, 10397d8d0e25Snbeams const CeedScalar *interp1d, 10407d8d0e25Snbeams const CeedScalar *grad1d, 10417d8d0e25Snbeams const CeedScalar *qref1d, 
10427d8d0e25Snbeams const CeedScalar *qweight1d, 10437d8d0e25Snbeams CeedBasis basis) { 10447d8d0e25Snbeams int ierr; 10457d8d0e25Snbeams Ceed ceed; 1046e15f9bd0SJeremy L Thompson ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr); 10477d8d0e25Snbeams CeedBasis_Hip_shared *data; 1048e15f9bd0SJeremy L Thompson ierr = CeedCalloc(1, &data); CeedChkBackend(ierr); 10497d8d0e25Snbeams 10507d8d0e25Snbeams // Copy basis data to GPU 10517d8d0e25Snbeams const CeedInt qBytes = Q1d * sizeof(CeedScalar); 10527d8d0e25Snbeams ierr = hipMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Hip(ceed, ierr); 10537d8d0e25Snbeams ierr = hipMemcpy(data->d_qweight1d, qweight1d, qBytes, 10547d8d0e25Snbeams hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr); 10557d8d0e25Snbeams 10567d8d0e25Snbeams const CeedInt iBytes = qBytes * P1d; 10577d8d0e25Snbeams ierr = hipMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Hip(ceed, ierr); 10587d8d0e25Snbeams ierr = hipMemcpy(data->d_interp1d, interp1d, iBytes, 10597d8d0e25Snbeams hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr); 10607d8d0e25Snbeams 10617d8d0e25Snbeams ierr = hipMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Hip(ceed, ierr); 10627d8d0e25Snbeams ierr = hipMemcpy(data->d_grad1d, grad1d, iBytes, 10637d8d0e25Snbeams hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr); 10647d8d0e25Snbeams 10657d8d0e25Snbeams // Compute collocated gradient and copy to GPU 10667d8d0e25Snbeams data->d_collograd1d = NULL; 10677d8d0e25Snbeams if (dim == 3 && Q1d >= P1d) { 10687d8d0e25Snbeams CeedScalar *collograd1d; 1069e15f9bd0SJeremy L Thompson ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChkBackend(ierr); 1070e15f9bd0SJeremy L Thompson ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChkBackend(ierr); 10717d8d0e25Snbeams ierr = hipMalloc((void **)&data->d_collograd1d, qBytes * Q1d); 10727d8d0e25Snbeams CeedChk_Hip(ceed, ierr); 10737d8d0e25Snbeams ierr = hipMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d, 10747d8d0e25Snbeams 
hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr); 1075e15f9bd0SJeremy L Thompson ierr = CeedFree(&collograd1d); CeedChkBackend(ierr); 10767d8d0e25Snbeams } 10777d8d0e25Snbeams 10789e31c45bSnbeams // Set number of threads per block for basis kernels 10797d8d0e25Snbeams CeedInt ncomp; 1080e15f9bd0SJeremy L Thompson ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChkBackend(ierr); 10819e31c45bSnbeams ierr = ComputeBasisThreadBlockSizes(dim, P1d, Q1d, ncomp, data->blksizes); 1082e15f9bd0SJeremy L Thompson CeedChkBackend(ierr); 10839e31c45bSnbeams 10849e31c45bSnbeams // Compile basis kernels 10859e31c45bSnbeams ierr = CeedCompileHip(ceed, kernelsShared, &data->module, 11, 10867d8d0e25Snbeams "Q1D", Q1d, 10877d8d0e25Snbeams "P1D", P1d, 10887d8d0e25Snbeams "T1D", CeedIntMax(Q1d, P1d), 10897d8d0e25Snbeams "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ? 10907d8d0e25Snbeams Q1d : P1d, dim), 10917d8d0e25Snbeams "BASIS_DIM", dim, 10927d8d0e25Snbeams "BASIS_NCOMP", ncomp, 10937d8d0e25Snbeams "BASIS_ELEMSIZE", CeedIntPow(P1d, dim), 10949e31c45bSnbeams "BASIS_NQPT", CeedIntPow(Q1d, dim), 10959e31c45bSnbeams "INTERP_BLKSIZE", data->blksizes[0], 10969e31c45bSnbeams "GRAD_BLKSIZE", data->blksizes[1], 10979e31c45bSnbeams "WEIGHT_BLKSIZE", data->blksizes[2] 1098e15f9bd0SJeremy L Thompson ); CeedChkBackend(ierr); 10997d8d0e25Snbeams ierr = CeedGetKernelHip(ceed, data->module, "interp", &data->interp); 1100e15f9bd0SJeremy L Thompson CeedChkBackend(ierr); 11017d8d0e25Snbeams ierr = CeedGetKernelHip(ceed, data->module, "grad", &data->grad); 1102e15f9bd0SJeremy L Thompson CeedChkBackend(ierr); 11037d8d0e25Snbeams ierr = CeedGetKernelHip(ceed, data->module, "weight", &data->weight); 1104e15f9bd0SJeremy L Thompson CeedChkBackend(ierr); 11057d8d0e25Snbeams 1106e15f9bd0SJeremy L Thompson ierr = CeedBasisSetData(basis, data); CeedChkBackend(ierr); 11077d8d0e25Snbeams 11087d8d0e25Snbeams // Register backend functions 11097d8d0e25Snbeams ierr = CeedSetBackendFunction(ceed, "Basis", basis, 
"Apply", 11107d8d0e25Snbeams CeedBasisApplyTensor_Hip_shared); 1111e15f9bd0SJeremy L Thompson CeedChkBackend(ierr); 11127d8d0e25Snbeams ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", 1113e15f9bd0SJeremy L Thompson CeedBasisDestroy_Hip_shared); CeedChkBackend(ierr); 1114e15f9bd0SJeremy L Thompson return CEED_ERROR_SUCCESS; 11157d8d0e25Snbeams } 11167d8d0e25Snbeams //------------------------------------------------------------------------------ 1117