// Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
// Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
// All Rights reserved. See files LICENSE and NOTICE for details.
//
// This file is part of CEED, a collection of benchmarks, miniapps, software
// libraries and APIs for efficient high-order finite element and spectral
// element discretizations for exascale applications. For more information and
// source code availability see http://github.com/ceed.
//
// The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
// a collaborative effort of two U.S. Department of Energy organizations (Office
// of Science and the National Nuclear Security Administration) responsible for
// the planning and preparation of a capable exascale ecosystem, including
// software, applications, hardware, advanced system engineering and early
// testbed platforms, in support of the nation's exascale computing imperative.
16*7d8d0e25Snbeams 17*7d8d0e25Snbeams #include "ceed-hip-shared.h" 18*7d8d0e25Snbeams #include "../hip/ceed-hip-compile.h" 19*7d8d0e25Snbeams 20*7d8d0e25Snbeams //------------------------------------------------------------------------------ 21*7d8d0e25Snbeams // Shared mem kernels 22*7d8d0e25Snbeams //------------------------------------------------------------------------------ 23*7d8d0e25Snbeams // *INDENT-OFF* 24*7d8d0e25Snbeams static const char *kernelsShared = QUOTE( 25*7d8d0e25Snbeams 26*7d8d0e25Snbeams //------------------------------------------------------------------------------ 27*7d8d0e25Snbeams // Sum input into output 28*7d8d0e25Snbeams //------------------------------------------------------------------------------ 29*7d8d0e25Snbeams inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) { 30*7d8d0e25Snbeams for (int i = 0; i < P1D; i++) 31*7d8d0e25Snbeams r_V[i] += r_U[i]; 32*7d8d0e25Snbeams } 33*7d8d0e25Snbeams 34*7d8d0e25Snbeams //------------------------------------------------------------------------------ 35*7d8d0e25Snbeams // 1D 36*7d8d0e25Snbeams //------------------------------------------------------------------------------ 37*7d8d0e25Snbeams 38*7d8d0e25Snbeams //------------------------------------------------------------------------------ 39*7d8d0e25Snbeams // Read DoFs 40*7d8d0e25Snbeams //------------------------------------------------------------------------------ 41*7d8d0e25Snbeams inline __device__ void readDofs1d(const int elem, const int tidx, 42*7d8d0e25Snbeams const int tidy, const int tidz,const int comp, 43*7d8d0e25Snbeams const int nelem, const CeedScalar *d_U, 44*7d8d0e25Snbeams CeedScalar *slice) { 45*7d8d0e25Snbeams for (int i = 0; i < P1D; i++) 46*7d8d0e25Snbeams slice[i + tidz*T1D] = d_U[i + elem*P1D + comp*P1D*nelem]; 47*7d8d0e25Snbeams for (int i = P1D; i < Q1D; i++) 48*7d8d0e25Snbeams slice[i + tidz*T1D] = 0.0; 49*7d8d0e25Snbeams } 50*7d8d0e25Snbeams 51*7d8d0e25Snbeams 
//------------------------------------------------------------------------------ 52*7d8d0e25Snbeams // Write DoFs 53*7d8d0e25Snbeams //------------------------------------------------------------------------------ 54*7d8d0e25Snbeams inline __device__ void writeDofs1d(const int elem, const int tidx, 55*7d8d0e25Snbeams const int tidy, const int comp, 56*7d8d0e25Snbeams const int nelem, const CeedScalar &r_V, 57*7d8d0e25Snbeams CeedScalar *d_V) { 58*7d8d0e25Snbeams if (tidx<P1D) 59*7d8d0e25Snbeams d_V[tidx + elem*P1D + comp*P1D*nelem] = r_V; 60*7d8d0e25Snbeams } 61*7d8d0e25Snbeams 62*7d8d0e25Snbeams //------------------------------------------------------------------------------ 63*7d8d0e25Snbeams // Read quadrature point data 64*7d8d0e25Snbeams //------------------------------------------------------------------------------ 65*7d8d0e25Snbeams inline __device__ void readQuads1d(const int elem, const int tidx, 66*7d8d0e25Snbeams const int tidy, const int tidz, const int comp, 67*7d8d0e25Snbeams const int dim, const int nelem, 68*7d8d0e25Snbeams const CeedScalar *d_U, CeedScalar *slice) { 69*7d8d0e25Snbeams for (int i = 0; i < Q1D; i++) 70*7d8d0e25Snbeams slice[i + tidz*T1D] = d_U[i + elem*Q1D + comp*Q1D*nelem + 71*7d8d0e25Snbeams dim*BASIS_NCOMP*nelem*Q1D]; 72*7d8d0e25Snbeams for (int i = Q1D; i < P1D; i++) 73*7d8d0e25Snbeams slice[i + tidz*T1D] = 0.0; 74*7d8d0e25Snbeams } 75*7d8d0e25Snbeams 76*7d8d0e25Snbeams //------------------------------------------------------------------------------ 77*7d8d0e25Snbeams // Write quadrature point data 78*7d8d0e25Snbeams //------------------------------------------------------------------------------ 79*7d8d0e25Snbeams inline __device__ void writeQuads1d(const int elem, const int tidx, 80*7d8d0e25Snbeams const int tidy, const int comp, 81*7d8d0e25Snbeams const int dim, const int nelem, 82*7d8d0e25Snbeams const CeedScalar &r_V, CeedScalar *d_V) { 83*7d8d0e25Snbeams if (tidx<Q1D) 84*7d8d0e25Snbeams d_V[tidx + elem*Q1D + comp*Q1D*nelem 
+ dim*BASIS_NCOMP*nelem*Q1D] = r_V; 85*7d8d0e25Snbeams } 86*7d8d0e25Snbeams 87*7d8d0e25Snbeams //------------------------------------------------------------------------------ 88*7d8d0e25Snbeams // 1D tensor contraction 89*7d8d0e25Snbeams //------------------------------------------------------------------------------ 90*7d8d0e25Snbeams inline __device__ void ContractX1d(CeedScalar *slice, const int tidx, 91*7d8d0e25Snbeams const int tidy, const int tidz, 92*7d8d0e25Snbeams const CeedScalar &U, const CeedScalar *B, 93*7d8d0e25Snbeams CeedScalar &V) { 94*7d8d0e25Snbeams V = 0.0; 95*7d8d0e25Snbeams for (int i = 0; i < P1D; ++i) 96*7d8d0e25Snbeams V += B[i + tidx*P1D] * slice[i + tidz*T1D]; // Contract x direction 97*7d8d0e25Snbeams } 98*7d8d0e25Snbeams 99*7d8d0e25Snbeams //------------------------------------------------------------------------------ 100*7d8d0e25Snbeams // 1D transpose tensor contraction 101*7d8d0e25Snbeams //------------------------------------------------------------------------------ 102*7d8d0e25Snbeams inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx, 103*7d8d0e25Snbeams const int tidy, const int tidz, 104*7d8d0e25Snbeams const CeedScalar &U, const CeedScalar *B, CeedScalar &V) { 105*7d8d0e25Snbeams V = 0.0; 106*7d8d0e25Snbeams for (int i = 0; i < Q1D; ++i) 107*7d8d0e25Snbeams V += B[tidx + i*P1D] * slice[i + tidz*T1D]; // Contract x direction 108*7d8d0e25Snbeams } 109*7d8d0e25Snbeams 110*7d8d0e25Snbeams //------------------------------------------------------------------------------ 111*7d8d0e25Snbeams // 1D interpolate to quadrature points 112*7d8d0e25Snbeams //------------------------------------------------------------------------------ 113*7d8d0e25Snbeams inline __device__ void interp1d(const CeedInt nelem, const int transpose, 114*7d8d0e25Snbeams const CeedScalar *c_B, 115*7d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 116*7d8d0e25Snbeams CeedScalar *__restrict__ d_V, 117*7d8d0e25Snbeams CeedScalar *slice) 
{ 118*7d8d0e25Snbeams CeedScalar r_V; 119*7d8d0e25Snbeams CeedScalar r_t; 120*7d8d0e25Snbeams 121*7d8d0e25Snbeams const int tidx = threadIdx.x; 122*7d8d0e25Snbeams const int tidy = threadIdx.y; 123*7d8d0e25Snbeams const int tidz = threadIdx.z; 124*7d8d0e25Snbeams 125*7d8d0e25Snbeams 126*7d8d0e25Snbeams for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem; 127*7d8d0e25Snbeams elem += gridDim.x*blockDim.z) { 128*7d8d0e25Snbeams for (int comp = 0; comp < BASIS_NCOMP; comp++) { 129*7d8d0e25Snbeams if (!transpose) { 130*7d8d0e25Snbeams readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice); 131*7d8d0e25Snbeams ContractX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V); 132*7d8d0e25Snbeams writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V); 133*7d8d0e25Snbeams } else { 134*7d8d0e25Snbeams readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice); 135*7d8d0e25Snbeams ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V); 136*7d8d0e25Snbeams writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V); 137*7d8d0e25Snbeams } 138*7d8d0e25Snbeams } 139*7d8d0e25Snbeams } 140*7d8d0e25Snbeams } 141*7d8d0e25Snbeams 142*7d8d0e25Snbeams //------------------------------------------------------------------------------ 143*7d8d0e25Snbeams // 1D derivatives at quadrature points 144*7d8d0e25Snbeams //------------------------------------------------------------------------------ 145*7d8d0e25Snbeams inline __device__ void grad1d(const CeedInt nelem, const int transpose, 146*7d8d0e25Snbeams const CeedScalar *c_B, const CeedScalar *c_G, 147*7d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 148*7d8d0e25Snbeams CeedScalar *__restrict__ d_V, 149*7d8d0e25Snbeams CeedScalar *slice) { 150*7d8d0e25Snbeams CeedScalar r_U; 151*7d8d0e25Snbeams CeedScalar r_V; 152*7d8d0e25Snbeams 153*7d8d0e25Snbeams const int tidx = threadIdx.x; 154*7d8d0e25Snbeams const int tidy = threadIdx.y; 155*7d8d0e25Snbeams const int tidz = threadIdx.z; 156*7d8d0e25Snbeams int dim; 
157*7d8d0e25Snbeams 158*7d8d0e25Snbeams for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem; 159*7d8d0e25Snbeams elem += gridDim.x*blockDim.z) { 160*7d8d0e25Snbeams for(int comp = 0; comp < BASIS_NCOMP; comp++) { 161*7d8d0e25Snbeams if (!transpose) { 162*7d8d0e25Snbeams readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice); 163*7d8d0e25Snbeams ContractX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V); 164*7d8d0e25Snbeams dim = 0; 165*7d8d0e25Snbeams writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V); 166*7d8d0e25Snbeams } else { 167*7d8d0e25Snbeams dim = 0; 168*7d8d0e25Snbeams readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice); 169*7d8d0e25Snbeams ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V); 170*7d8d0e25Snbeams writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V); 171*7d8d0e25Snbeams } 172*7d8d0e25Snbeams } 173*7d8d0e25Snbeams } 174*7d8d0e25Snbeams } 175*7d8d0e25Snbeams 176*7d8d0e25Snbeams //------------------------------------------------------------------------------ 177*7d8d0e25Snbeams // 1D Quadrature weights 178*7d8d0e25Snbeams //------------------------------------------------------------------------------ 179*7d8d0e25Snbeams __device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d, 180*7d8d0e25Snbeams CeedScalar *w) { 181*7d8d0e25Snbeams const int tid = threadIdx.x; 182*7d8d0e25Snbeams const CeedScalar weight = qweight1d[tid]; 183*7d8d0e25Snbeams for (CeedInt elem = blockIdx.x*blockDim.y + threadIdx.y; elem < nelem; 184*7d8d0e25Snbeams elem += gridDim.x*blockDim.y) { 185*7d8d0e25Snbeams const int ind = elem*Q1D + tid; 186*7d8d0e25Snbeams w[ind] = weight; 187*7d8d0e25Snbeams } 188*7d8d0e25Snbeams } 189*7d8d0e25Snbeams 190*7d8d0e25Snbeams //------------------------------------------------------------------------------ 191*7d8d0e25Snbeams // 2D 192*7d8d0e25Snbeams //------------------------------------------------------------------------------ 193*7d8d0e25Snbeams 194*7d8d0e25Snbeams 
//------------------------------------------------------------------------------ 195*7d8d0e25Snbeams // Read DoFs 196*7d8d0e25Snbeams //------------------------------------------------------------------------------ 197*7d8d0e25Snbeams inline __device__ void readDofs2d(const int elem, const int tidx, 198*7d8d0e25Snbeams const int tidy, const int comp, 199*7d8d0e25Snbeams const int nelem, const CeedScalar *d_U, 200*7d8d0e25Snbeams CeedScalar &U) { 201*7d8d0e25Snbeams U = (tidx<P1D && tidy<P1D) ? 202*7d8d0e25Snbeams d_U[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] : 0.0; 203*7d8d0e25Snbeams } 204*7d8d0e25Snbeams 205*7d8d0e25Snbeams //------------------------------------------------------------------------------ 206*7d8d0e25Snbeams // Write DoFs 207*7d8d0e25Snbeams //------------------------------------------------------------------------------ 208*7d8d0e25Snbeams inline __device__ void writeDofs2d(const int elem, const int tidx, 209*7d8d0e25Snbeams const int tidy, const int comp, 210*7d8d0e25Snbeams const int nelem, const CeedScalar &r_V, 211*7d8d0e25Snbeams CeedScalar *d_V) { 212*7d8d0e25Snbeams if (tidx<P1D && tidy<P1D) 213*7d8d0e25Snbeams d_V[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] = r_V; 214*7d8d0e25Snbeams } 215*7d8d0e25Snbeams 216*7d8d0e25Snbeams //------------------------------------------------------------------------------ 217*7d8d0e25Snbeams // Read quadrature point data 218*7d8d0e25Snbeams //------------------------------------------------------------------------------ 219*7d8d0e25Snbeams inline __device__ void readQuads2d(const int elem, const int tidx, 220*7d8d0e25Snbeams const int tidy, const int comp, 221*7d8d0e25Snbeams const int dim, const int nelem, 222*7d8d0e25Snbeams const CeedScalar *d_U, CeedScalar &U ) { 223*7d8d0e25Snbeams U = (tidx<Q1D && tidy<Q1D) ? 
224*7d8d0e25Snbeams d_U[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem + 225*7d8d0e25Snbeams dim*BASIS_NCOMP*nelem*Q1D*Q1D] : 0.0; 226*7d8d0e25Snbeams } 227*7d8d0e25Snbeams 228*7d8d0e25Snbeams //------------------------------------------------------------------------------ 229*7d8d0e25Snbeams // Write quadrature point data 230*7d8d0e25Snbeams //------------------------------------------------------------------------------ 231*7d8d0e25Snbeams inline __device__ void writeQuads2d(const int elem, const int tidx, 232*7d8d0e25Snbeams const int tidy, const int comp, 233*7d8d0e25Snbeams const int dim, const int nelem, 234*7d8d0e25Snbeams const CeedScalar &r_V, CeedScalar *d_V) { 235*7d8d0e25Snbeams if (tidx<Q1D && tidy<Q1D) 236*7d8d0e25Snbeams d_V[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem + 237*7d8d0e25Snbeams dim*BASIS_NCOMP*nelem*Q1D*Q1D] = r_V; 238*7d8d0e25Snbeams } 239*7d8d0e25Snbeams 240*7d8d0e25Snbeams //------------------------------------------------------------------------------ 241*7d8d0e25Snbeams // 2D tensor contraction x 242*7d8d0e25Snbeams //------------------------------------------------------------------------------ 243*7d8d0e25Snbeams inline __device__ void ContractX2d(CeedScalar *slice, const int tidx, 244*7d8d0e25Snbeams const int tidy, const int tidz, 245*7d8d0e25Snbeams const CeedScalar &U, const CeedScalar *B, 246*7d8d0e25Snbeams CeedScalar &V) { 247*7d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U; 248*7d8d0e25Snbeams __syncthreads(); 249*7d8d0e25Snbeams V = 0.0; 250*7d8d0e25Snbeams if (tidx < Q1D) 251*7d8d0e25Snbeams for (int i = 0; i < P1D; ++i) 252*7d8d0e25Snbeams V += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction 253*7d8d0e25Snbeams __syncthreads(); 254*7d8d0e25Snbeams } 255*7d8d0e25Snbeams 256*7d8d0e25Snbeams //------------------------------------------------------------------------------ 257*7d8d0e25Snbeams // 2D tensor contraction y 258*7d8d0e25Snbeams 
//------------------------------------------------------------------------------ 259*7d8d0e25Snbeams inline __device__ void ContractY2d(CeedScalar *slice, const int tidx, 260*7d8d0e25Snbeams const int tidy, const int tidz, 261*7d8d0e25Snbeams const CeedScalar &U, const CeedScalar *B, 262*7d8d0e25Snbeams CeedScalar &V) { 263*7d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U; 264*7d8d0e25Snbeams __syncthreads(); 265*7d8d0e25Snbeams V = 0.0; 266*7d8d0e25Snbeams if (tidy < Q1D) 267*7d8d0e25Snbeams for (int i = 0; i < P1D; ++i) 268*7d8d0e25Snbeams V += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction 269*7d8d0e25Snbeams __syncthreads(); 270*7d8d0e25Snbeams } 271*7d8d0e25Snbeams 272*7d8d0e25Snbeams //------------------------------------------------------------------------------ 273*7d8d0e25Snbeams // 2D transpose tensor contraction y 274*7d8d0e25Snbeams //------------------------------------------------------------------------------ 275*7d8d0e25Snbeams inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx, 276*7d8d0e25Snbeams const int tidy, const int tidz, 277*7d8d0e25Snbeams const CeedScalar &U, const CeedScalar *B, CeedScalar &V) { 278*7d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U; 279*7d8d0e25Snbeams __syncthreads(); 280*7d8d0e25Snbeams V = 0.0; 281*7d8d0e25Snbeams if (tidy < P1D) 282*7d8d0e25Snbeams for (int i = 0; i < Q1D; ++i) 283*7d8d0e25Snbeams V += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction 284*7d8d0e25Snbeams __syncthreads(); 285*7d8d0e25Snbeams } 286*7d8d0e25Snbeams 287*7d8d0e25Snbeams //------------------------------------------------------------------------------ 288*7d8d0e25Snbeams // 2D transpose tensor contraction x 289*7d8d0e25Snbeams //------------------------------------------------------------------------------ 290*7d8d0e25Snbeams inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx, 291*7d8d0e25Snbeams const int tidy, 
const int tidz, 292*7d8d0e25Snbeams const CeedScalar &U, const CeedScalar *B, CeedScalar &V) { 293*7d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U; 294*7d8d0e25Snbeams __syncthreads(); 295*7d8d0e25Snbeams V = 0.0; 296*7d8d0e25Snbeams if (tidx < P1D) 297*7d8d0e25Snbeams for (int i = 0; i < Q1D; ++i) 298*7d8d0e25Snbeams V += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction 299*7d8d0e25Snbeams __syncthreads(); 300*7d8d0e25Snbeams } 301*7d8d0e25Snbeams 302*7d8d0e25Snbeams //------------------------------------------------------------------------------ 303*7d8d0e25Snbeams // 2D interpolate to quadrature points 304*7d8d0e25Snbeams //------------------------------------------------------------------------------ 305*7d8d0e25Snbeams inline __device__ void interp2d(const CeedInt nelem, const int transpose, 306*7d8d0e25Snbeams const CeedScalar *c_B, 307*7d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 308*7d8d0e25Snbeams CeedScalar *__restrict__ d_V, 309*7d8d0e25Snbeams CeedScalar *slice) { 310*7d8d0e25Snbeams CeedScalar r_V; 311*7d8d0e25Snbeams CeedScalar r_t; 312*7d8d0e25Snbeams 313*7d8d0e25Snbeams const int tidx = threadIdx.x; 314*7d8d0e25Snbeams const int tidy = threadIdx.y; 315*7d8d0e25Snbeams const int tidz = threadIdx.z; 316*7d8d0e25Snbeams const int blockElem = tidz/BASIS_NCOMP; 317*7d8d0e25Snbeams const int elemsPerBlock = blockDim.z/BASIS_NCOMP; 318*7d8d0e25Snbeams const int comp = tidz%BASIS_NCOMP; 319*7d8d0e25Snbeams 320*7d8d0e25Snbeams for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem; 321*7d8d0e25Snbeams elem += gridDim.x*elemsPerBlock) { 322*7d8d0e25Snbeams const int comp = tidz%BASIS_NCOMP; 323*7d8d0e25Snbeams r_V = 0.0; 324*7d8d0e25Snbeams r_t = 0.0; 325*7d8d0e25Snbeams if (!transpose) { 326*7d8d0e25Snbeams readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V); 327*7d8d0e25Snbeams ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t); 328*7d8d0e25Snbeams ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, 
r_V); 329*7d8d0e25Snbeams writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V); 330*7d8d0e25Snbeams } else { 331*7d8d0e25Snbeams readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V); 332*7d8d0e25Snbeams ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t); 333*7d8d0e25Snbeams ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V); 334*7d8d0e25Snbeams writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V); 335*7d8d0e25Snbeams } 336*7d8d0e25Snbeams } 337*7d8d0e25Snbeams } 338*7d8d0e25Snbeams 339*7d8d0e25Snbeams //------------------------------------------------------------------------------ 340*7d8d0e25Snbeams // 2D derivatives at quadrature points 341*7d8d0e25Snbeams //------------------------------------------------------------------------------ 342*7d8d0e25Snbeams inline __device__ void grad2d(const CeedInt nelem, const int transpose, 343*7d8d0e25Snbeams const CeedScalar *c_B, const CeedScalar *c_G, 344*7d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 345*7d8d0e25Snbeams CeedScalar *__restrict__ d_V, CeedScalar *slice) { 346*7d8d0e25Snbeams CeedScalar r_U; 347*7d8d0e25Snbeams CeedScalar r_V; 348*7d8d0e25Snbeams CeedScalar r_t; 349*7d8d0e25Snbeams 350*7d8d0e25Snbeams const int tidx = threadIdx.x; 351*7d8d0e25Snbeams const int tidy = threadIdx.y; 352*7d8d0e25Snbeams const int tidz = threadIdx.z; 353*7d8d0e25Snbeams const int blockElem = tidz/BASIS_NCOMP; 354*7d8d0e25Snbeams const int elemsPerBlock = blockDim.z/BASIS_NCOMP; 355*7d8d0e25Snbeams const int comp = tidz%BASIS_NCOMP; 356*7d8d0e25Snbeams int dim; 357*7d8d0e25Snbeams 358*7d8d0e25Snbeams for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem; 359*7d8d0e25Snbeams elem += gridDim.x*elemsPerBlock) { 360*7d8d0e25Snbeams if (!transpose) { 361*7d8d0e25Snbeams readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U); 362*7d8d0e25Snbeams ContractX2d(slice, tidx, tidy, tidz, r_U, c_G, r_t); 363*7d8d0e25Snbeams ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V); 364*7d8d0e25Snbeams dim 
= 0; 365*7d8d0e25Snbeams writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V); 366*7d8d0e25Snbeams ContractX2d(slice, tidx, tidy, tidz, r_U, c_B, r_t); 367*7d8d0e25Snbeams ContractY2d(slice, tidx, tidy, tidz, r_t, c_G, r_V); 368*7d8d0e25Snbeams dim = 1; 369*7d8d0e25Snbeams writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V); 370*7d8d0e25Snbeams } else { 371*7d8d0e25Snbeams dim = 0; 372*7d8d0e25Snbeams readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U); 373*7d8d0e25Snbeams ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_B, r_t); 374*7d8d0e25Snbeams ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_G, r_V); 375*7d8d0e25Snbeams dim = 1; 376*7d8d0e25Snbeams readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U); 377*7d8d0e25Snbeams ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_G, r_t); 378*7d8d0e25Snbeams ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_U); 379*7d8d0e25Snbeams r_V += r_U; 380*7d8d0e25Snbeams writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V); 381*7d8d0e25Snbeams } 382*7d8d0e25Snbeams } 383*7d8d0e25Snbeams } 384*7d8d0e25Snbeams 385*7d8d0e25Snbeams //------------------------------------------------------------------------------ 386*7d8d0e25Snbeams // 2D quadrature weights 387*7d8d0e25Snbeams //------------------------------------------------------------------------------ 388*7d8d0e25Snbeams __device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d, 389*7d8d0e25Snbeams CeedScalar *w) { 390*7d8d0e25Snbeams const int i = threadIdx.x; 391*7d8d0e25Snbeams const int j = threadIdx.y; 392*7d8d0e25Snbeams const CeedScalar weight = qweight1d[i]*qweight1d[j]; 393*7d8d0e25Snbeams for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem; 394*7d8d0e25Snbeams elem += gridDim.x*blockDim.z) { 395*7d8d0e25Snbeams const int ind = elem*Q1D*Q1D + i + j*Q1D; 396*7d8d0e25Snbeams w[ind] = weight; 397*7d8d0e25Snbeams } 398*7d8d0e25Snbeams } 399*7d8d0e25Snbeams 400*7d8d0e25Snbeams 
//------------------------------------------------------------------------------ 401*7d8d0e25Snbeams // 3D 402*7d8d0e25Snbeams //------------------------------------------------------------------------------ 403*7d8d0e25Snbeams 404*7d8d0e25Snbeams //------------------------------------------------------------------------------ 405*7d8d0e25Snbeams // Read DoFs 406*7d8d0e25Snbeams //------------------------------------------------------------------------------ 407*7d8d0e25Snbeams inline __device__ void readDofs3d(const int elem, const int tidx, 408*7d8d0e25Snbeams const int tidy, const int comp, 409*7d8d0e25Snbeams const int nelem, const CeedScalar *d_U, 410*7d8d0e25Snbeams CeedScalar *r_U) { 411*7d8d0e25Snbeams for (int i = 0; i < P1D; i++) 412*7d8d0e25Snbeams r_U[i] = (tidx < P1D && tidy < P1D) ? 413*7d8d0e25Snbeams d_U[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D + 414*7d8d0e25Snbeams comp*P1D*P1D*P1D*nelem] : 0.0; 415*7d8d0e25Snbeams for (int i = P1D; i < Q1D; i++) 416*7d8d0e25Snbeams r_U[i] = 0.0; 417*7d8d0e25Snbeams } 418*7d8d0e25Snbeams 419*7d8d0e25Snbeams //------------------------------------------------------------------------------ 420*7d8d0e25Snbeams // Write DoFs 421*7d8d0e25Snbeams //------------------------------------------------------------------------------ 422*7d8d0e25Snbeams inline __device__ void writeDofs3d(const int elem, const int tidx, 423*7d8d0e25Snbeams const int tidy, const int comp, 424*7d8d0e25Snbeams const int nelem, const CeedScalar *r_V, 425*7d8d0e25Snbeams CeedScalar *d_V) { 426*7d8d0e25Snbeams if (tidx < P1D && tidy < P1D) { 427*7d8d0e25Snbeams for (int i = 0; i < P1D; i++) 428*7d8d0e25Snbeams d_V[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D + 429*7d8d0e25Snbeams comp*P1D*P1D*P1D*nelem] = r_V[i]; 430*7d8d0e25Snbeams } 431*7d8d0e25Snbeams } 432*7d8d0e25Snbeams 433*7d8d0e25Snbeams //------------------------------------------------------------------------------ 434*7d8d0e25Snbeams // Read quadrature point data 435*7d8d0e25Snbeams 
//------------------------------------------------------------------------------ 436*7d8d0e25Snbeams inline __device__ void readQuads3d(const int elem, const int tidx, 437*7d8d0e25Snbeams const int tidy, const int comp, 438*7d8d0e25Snbeams const int dim, const int nelem, 439*7d8d0e25Snbeams const CeedScalar *d_U, CeedScalar *r_U) { 440*7d8d0e25Snbeams for (int i = 0; i < Q1D; i++) 441*7d8d0e25Snbeams r_U[i] = (tidx < Q1D && tidy < Q1D) ? 442*7d8d0e25Snbeams d_U[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D + 443*7d8d0e25Snbeams comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] : 0.0; 444*7d8d0e25Snbeams for (int i = Q1D; i < P1D; i++) 445*7d8d0e25Snbeams r_U[i] = 0.0; 446*7d8d0e25Snbeams } 447*7d8d0e25Snbeams 448*7d8d0e25Snbeams //------------------------------------------------------------------------------ 449*7d8d0e25Snbeams // Write quadrature point data 450*7d8d0e25Snbeams //------------------------------------------------------------------------------ 451*7d8d0e25Snbeams inline __device__ void writeQuads3d(const int elem, const int tidx, 452*7d8d0e25Snbeams const int tidy, const int comp, 453*7d8d0e25Snbeams const int dim, const int nelem, 454*7d8d0e25Snbeams const CeedScalar *r_V, CeedScalar *d_V) { 455*7d8d0e25Snbeams if (tidx < Q1D && tidy < Q1D) { 456*7d8d0e25Snbeams for (int i = 0; i < Q1D; i++) 457*7d8d0e25Snbeams d_V[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem + 458*7d8d0e25Snbeams dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] = r_V[i]; 459*7d8d0e25Snbeams } 460*7d8d0e25Snbeams } 461*7d8d0e25Snbeams 462*7d8d0e25Snbeams //------------------------------------------------------------------------------ 463*7d8d0e25Snbeams // 3D tensor contract x 464*7d8d0e25Snbeams //------------------------------------------------------------------------------ 465*7d8d0e25Snbeams inline __device__ void ContractX3d(CeedScalar *slice, const int tidx, 466*7d8d0e25Snbeams const int tidy, const int tidz, 467*7d8d0e25Snbeams const CeedScalar *U, 
468*7d8d0e25Snbeams const CeedScalar *B, 469*7d8d0e25Snbeams CeedScalar *V) { 470*7d8d0e25Snbeams for (int k = 0; k < P1D; ++k) { 471*7d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k]; 472*7d8d0e25Snbeams __syncthreads(); 473*7d8d0e25Snbeams V[k] = 0.0; 474*7d8d0e25Snbeams if (tidx < Q1D && tidy < P1D) 475*7d8d0e25Snbeams for (int i = 0; i < P1D; ++i) 476*7d8d0e25Snbeams V[k] += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction 477*7d8d0e25Snbeams __syncthreads(); 478*7d8d0e25Snbeams } 479*7d8d0e25Snbeams } 480*7d8d0e25Snbeams 481*7d8d0e25Snbeams //------------------------------------------------------------------------------ 482*7d8d0e25Snbeams // 3D tensor contract y 483*7d8d0e25Snbeams //------------------------------------------------------------------------------ 484*7d8d0e25Snbeams inline __device__ void ContractY3d(CeedScalar *slice, const int tidx, 485*7d8d0e25Snbeams const int tidy, const int tidz, 486*7d8d0e25Snbeams const CeedScalar *U, 487*7d8d0e25Snbeams const CeedScalar *B, 488*7d8d0e25Snbeams CeedScalar *V) { 489*7d8d0e25Snbeams for (int k = 0; k < P1D; ++k) { 490*7d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k]; 491*7d8d0e25Snbeams __syncthreads(); 492*7d8d0e25Snbeams V[k] = 0.0; 493*7d8d0e25Snbeams if (tidx < Q1D && tidy < Q1D) 494*7d8d0e25Snbeams for (int i = 0; i < P1D; ++i) 495*7d8d0e25Snbeams V[k] += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction 496*7d8d0e25Snbeams __syncthreads(); 497*7d8d0e25Snbeams } 498*7d8d0e25Snbeams } 499*7d8d0e25Snbeams 500*7d8d0e25Snbeams //------------------------------------------------------------------------------ 501*7d8d0e25Snbeams // 3D tensor contract z 502*7d8d0e25Snbeams //------------------------------------------------------------------------------ 503*7d8d0e25Snbeams inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx, 504*7d8d0e25Snbeams const int tidy, const int tidz, 505*7d8d0e25Snbeams const CeedScalar 
*U, 506*7d8d0e25Snbeams const CeedScalar *B, 507*7d8d0e25Snbeams CeedScalar *V) { 508*7d8d0e25Snbeams for (int k = 0; k < Q1D; ++k) { 509*7d8d0e25Snbeams V[k] = 0.0; 510*7d8d0e25Snbeams if (tidx < Q1D && tidy < Q1D) 511*7d8d0e25Snbeams for (int i = 0; i < P1D; ++i) 512*7d8d0e25Snbeams V[k] += B[i + k*P1D] * U[i]; // Contract z direction 513*7d8d0e25Snbeams } 514*7d8d0e25Snbeams for (int k = Q1D; k < P1D; ++k) 515*7d8d0e25Snbeams V[k] = 0.0; 516*7d8d0e25Snbeams } 517*7d8d0e25Snbeams 518*7d8d0e25Snbeams //------------------------------------------------------------------------------ 519*7d8d0e25Snbeams // 3D transpose tensor contract z 520*7d8d0e25Snbeams //------------------------------------------------------------------------------ 521*7d8d0e25Snbeams inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx, 522*7d8d0e25Snbeams const int tidy, const int tidz, 523*7d8d0e25Snbeams const CeedScalar *U, 524*7d8d0e25Snbeams const CeedScalar *B, 525*7d8d0e25Snbeams CeedScalar *V) { 526*7d8d0e25Snbeams for (int k = 0; k < P1D; ++k) { 527*7d8d0e25Snbeams V[k] = 0.0; 528*7d8d0e25Snbeams if (tidx < Q1D && tidy < Q1D) 529*7d8d0e25Snbeams for (int i = 0; i < Q1D; ++i) 530*7d8d0e25Snbeams V[k] += B[k + i*P1D] * U[i]; // Contract z direction 531*7d8d0e25Snbeams } 532*7d8d0e25Snbeams for (int k = P1D; k < Q1D; ++k) 533*7d8d0e25Snbeams V[k] = 0.0; 534*7d8d0e25Snbeams } 535*7d8d0e25Snbeams 536*7d8d0e25Snbeams //------------------------------------------------------------------------------ 537*7d8d0e25Snbeams // 3D transpose tensor contract y 538*7d8d0e25Snbeams //------------------------------------------------------------------------------ 539*7d8d0e25Snbeams inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx, 540*7d8d0e25Snbeams const int tidy, const int tidz, 541*7d8d0e25Snbeams const CeedScalar *U, 542*7d8d0e25Snbeams const CeedScalar *B, 543*7d8d0e25Snbeams CeedScalar *V) { 544*7d8d0e25Snbeams for (int k = 0; k < P1D; ++k) { 
545*7d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k]; 546*7d8d0e25Snbeams __syncthreads(); 547*7d8d0e25Snbeams V[k] = 0.0; 548*7d8d0e25Snbeams if (tidx < Q1D && tidy < P1D) 549*7d8d0e25Snbeams for (int i = 0; i < Q1D; ++i) 550*7d8d0e25Snbeams V[k] += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction 551*7d8d0e25Snbeams __syncthreads(); 552*7d8d0e25Snbeams } 553*7d8d0e25Snbeams } 554*7d8d0e25Snbeams 555*7d8d0e25Snbeams //------------------------------------------------------------------------------ 556*7d8d0e25Snbeams // 3D transpose tensor contract x 557*7d8d0e25Snbeams //------------------------------------------------------------------------------ 558*7d8d0e25Snbeams inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx, 559*7d8d0e25Snbeams const int tidy, const int tidz, 560*7d8d0e25Snbeams const CeedScalar *U, 561*7d8d0e25Snbeams const CeedScalar *B, 562*7d8d0e25Snbeams CeedScalar *V) { 563*7d8d0e25Snbeams for (int k = 0; k < P1D; ++k) { 564*7d8d0e25Snbeams slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k]; 565*7d8d0e25Snbeams __syncthreads(); 566*7d8d0e25Snbeams V[k] = 0.0; 567*7d8d0e25Snbeams if (tidx < P1D && tidy < P1D) 568*7d8d0e25Snbeams for (int i = 0; i < Q1D; ++i) 569*7d8d0e25Snbeams V[k] += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction 570*7d8d0e25Snbeams __syncthreads(); 571*7d8d0e25Snbeams } 572*7d8d0e25Snbeams } 573*7d8d0e25Snbeams 574*7d8d0e25Snbeams //------------------------------------------------------------------------------ 575*7d8d0e25Snbeams // 3D interpolate to quadrature points 576*7d8d0e25Snbeams //------------------------------------------------------------------------------ 577*7d8d0e25Snbeams inline __device__ void interp3d(const CeedInt nelem, const int transpose, 578*7d8d0e25Snbeams const CeedScalar *c_B, 579*7d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 580*7d8d0e25Snbeams CeedScalar *__restrict__ d_V, 581*7d8d0e25Snbeams CeedScalar 
*slice) { 582*7d8d0e25Snbeams CeedScalar r_V[T1D]; 583*7d8d0e25Snbeams CeedScalar r_t[T1D]; 584*7d8d0e25Snbeams 585*7d8d0e25Snbeams const int tidx = threadIdx.x; 586*7d8d0e25Snbeams const int tidy = threadIdx.y; 587*7d8d0e25Snbeams const int tidz = threadIdx.z; 588*7d8d0e25Snbeams const int blockElem = tidz/BASIS_NCOMP; 589*7d8d0e25Snbeams const int elemsPerBlock = blockDim.z/BASIS_NCOMP; 590*7d8d0e25Snbeams const int comp = tidz%BASIS_NCOMP; 591*7d8d0e25Snbeams 592*7d8d0e25Snbeams for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem; 593*7d8d0e25Snbeams elem += gridDim.x*elemsPerBlock) { 594*7d8d0e25Snbeams for (int i = 0; i < T1D; ++i) { 595*7d8d0e25Snbeams r_V[i] = 0.0; 596*7d8d0e25Snbeams r_t[i] = 0.0; 597*7d8d0e25Snbeams } 598*7d8d0e25Snbeams if (!transpose) { 599*7d8d0e25Snbeams readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V); 600*7d8d0e25Snbeams ContractX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t); 601*7d8d0e25Snbeams ContractY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V); 602*7d8d0e25Snbeams ContractZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t); 603*7d8d0e25Snbeams writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V); 604*7d8d0e25Snbeams } else { 605*7d8d0e25Snbeams readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V); 606*7d8d0e25Snbeams ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t); 607*7d8d0e25Snbeams ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V); 608*7d8d0e25Snbeams ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t); 609*7d8d0e25Snbeams writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V); 610*7d8d0e25Snbeams } 611*7d8d0e25Snbeams } 612*7d8d0e25Snbeams } 613*7d8d0e25Snbeams 614*7d8d0e25Snbeams //------------------------------------------------------------------------------ 615*7d8d0e25Snbeams // 3D derivatives at quadrature points 616*7d8d0e25Snbeams //------------------------------------------------------------------------------ 617*7d8d0e25Snbeams inline __device__ void grad3d(const 
CeedInt nelem, const int transpose, 618*7d8d0e25Snbeams const CeedScalar *c_B, const CeedScalar *c_G, 619*7d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 620*7d8d0e25Snbeams CeedScalar *__restrict__ d_V, 621*7d8d0e25Snbeams CeedScalar *slice) { 622*7d8d0e25Snbeams // Use P1D for one of these 623*7d8d0e25Snbeams CeedScalar r_U[T1D]; 624*7d8d0e25Snbeams CeedScalar r_V[T1D]; 625*7d8d0e25Snbeams CeedScalar r_t[T1D]; 626*7d8d0e25Snbeams 627*7d8d0e25Snbeams const int tidx = threadIdx.x; 628*7d8d0e25Snbeams const int tidy = threadIdx.y; 629*7d8d0e25Snbeams const int tidz = threadIdx.z; 630*7d8d0e25Snbeams const int blockElem = tidz/BASIS_NCOMP; 631*7d8d0e25Snbeams const int elemsPerBlock = blockDim.z/BASIS_NCOMP; 632*7d8d0e25Snbeams const int comp = tidz%BASIS_NCOMP; 633*7d8d0e25Snbeams int dim; 634*7d8d0e25Snbeams 635*7d8d0e25Snbeams for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem; 636*7d8d0e25Snbeams elem += gridDim.x*elemsPerBlock) { 637*7d8d0e25Snbeams for (int i = 0; i < T1D; ++i) { 638*7d8d0e25Snbeams r_U[i] = 0.0; 639*7d8d0e25Snbeams r_V[i] = 0.0; 640*7d8d0e25Snbeams r_t[i] = 0.0; 641*7d8d0e25Snbeams } 642*7d8d0e25Snbeams if (!transpose) { 643*7d8d0e25Snbeams readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U); 644*7d8d0e25Snbeams ContractX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V); 645*7d8d0e25Snbeams ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t); 646*7d8d0e25Snbeams ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V); 647*7d8d0e25Snbeams dim = 0; 648*7d8d0e25Snbeams writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V); 649*7d8d0e25Snbeams ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V); 650*7d8d0e25Snbeams ContractY3d(slice, tidx, tidy, tidz, r_V, c_G, r_t); 651*7d8d0e25Snbeams ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V); 652*7d8d0e25Snbeams dim = 1; 653*7d8d0e25Snbeams writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V); 654*7d8d0e25Snbeams ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V); 
655*7d8d0e25Snbeams ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t); 656*7d8d0e25Snbeams ContractZ3d(slice, tidx, tidy, tidz, r_t, c_G, r_V); 657*7d8d0e25Snbeams dim = 2; 658*7d8d0e25Snbeams writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V); 659*7d8d0e25Snbeams } else { 660*7d8d0e25Snbeams dim = 0; 661*7d8d0e25Snbeams readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U); 662*7d8d0e25Snbeams ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t); 663*7d8d0e25Snbeams ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U); 664*7d8d0e25Snbeams ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V); 665*7d8d0e25Snbeams dim = 1; 666*7d8d0e25Snbeams readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U); 667*7d8d0e25Snbeams ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t); 668*7d8d0e25Snbeams ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_G, r_U); 669*7d8d0e25Snbeams ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t); 670*7d8d0e25Snbeams add(r_V, r_t); 671*7d8d0e25Snbeams dim = 2; 672*7d8d0e25Snbeams readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U); 673*7d8d0e25Snbeams ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_G, r_t); 674*7d8d0e25Snbeams ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U); 675*7d8d0e25Snbeams ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t); 676*7d8d0e25Snbeams add(r_V, r_t); 677*7d8d0e25Snbeams writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V); 678*7d8d0e25Snbeams } 679*7d8d0e25Snbeams } 680*7d8d0e25Snbeams } 681*7d8d0e25Snbeams 682*7d8d0e25Snbeams //------------------------------------------------------------------------------ 683*7d8d0e25Snbeams // 3D quadrature weights 684*7d8d0e25Snbeams //------------------------------------------------------------------------------ 685*7d8d0e25Snbeams __device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d, 686*7d8d0e25Snbeams CeedScalar *w) { 687*7d8d0e25Snbeams const int i = 
threadIdx.x; 688*7d8d0e25Snbeams const int j = threadIdx.y; 689*7d8d0e25Snbeams const int k = threadIdx.z; 690*7d8d0e25Snbeams const CeedScalar weight = qweight1d[i]*qweight1d[j]*qweight1d[k]; 691*7d8d0e25Snbeams for (int e = blockIdx.x; e < nelem; e += gridDim.x) { 692*7d8d0e25Snbeams const int ind = e*Q1D*Q1D*Q1D + i + j*Q1D + k*Q1D*Q1D; 693*7d8d0e25Snbeams w[ind] = weight; 694*7d8d0e25Snbeams } 695*7d8d0e25Snbeams } 696*7d8d0e25Snbeams 697*7d8d0e25Snbeams 698*7d8d0e25Snbeams //------------------------------------------------------------------------------ 699*7d8d0e25Snbeams // Basis kernels 700*7d8d0e25Snbeams //------------------------------------------------------------------------------ 701*7d8d0e25Snbeams 702*7d8d0e25Snbeams //------------------------------------------------------------------------------ 703*7d8d0e25Snbeams // Interp kernel by dim 704*7d8d0e25Snbeams //------------------------------------------------------------------------------ 705*7d8d0e25Snbeams extern "C" __global__ void interp(const CeedInt nelem, const int transpose, 706*7d8d0e25Snbeams const CeedScalar *c_B, 707*7d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 708*7d8d0e25Snbeams CeedScalar *__restrict__ d_V) { 709*7d8d0e25Snbeams HIP_DYNAMIC_SHARED( double, slice) 710*7d8d0e25Snbeams if (BASIS_DIM == 1) { 711*7d8d0e25Snbeams interp1d(nelem, transpose, c_B, d_U, d_V, slice); 712*7d8d0e25Snbeams } else if (BASIS_DIM == 2) { 713*7d8d0e25Snbeams interp2d(nelem, transpose, c_B, d_U, d_V, slice); 714*7d8d0e25Snbeams } else if (BASIS_DIM == 3) { 715*7d8d0e25Snbeams interp3d(nelem, transpose, c_B, d_U, d_V, slice); 716*7d8d0e25Snbeams } 717*7d8d0e25Snbeams } 718*7d8d0e25Snbeams 719*7d8d0e25Snbeams //------------------------------------------------------------------------------ 720*7d8d0e25Snbeams // Grad kernel by dim 721*7d8d0e25Snbeams //------------------------------------------------------------------------------ 722*7d8d0e25Snbeams extern "C" __global__ void grad(const CeedInt 
nelem, const int transpose, 723*7d8d0e25Snbeams const CeedScalar *c_B, const CeedScalar *c_G, 724*7d8d0e25Snbeams const CeedScalar *__restrict__ d_U, 725*7d8d0e25Snbeams CeedScalar *__restrict__ d_V) { 726*7d8d0e25Snbeams HIP_DYNAMIC_SHARED( double, slice) 727*7d8d0e25Snbeams if (BASIS_DIM == 1) { 728*7d8d0e25Snbeams grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice); 729*7d8d0e25Snbeams } else if (BASIS_DIM == 2) { 730*7d8d0e25Snbeams grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice); 731*7d8d0e25Snbeams } else if (BASIS_DIM == 3) { 732*7d8d0e25Snbeams grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice); 733*7d8d0e25Snbeams } 734*7d8d0e25Snbeams } 735*7d8d0e25Snbeams 736*7d8d0e25Snbeams //------------------------------------------------------------------------------ 737*7d8d0e25Snbeams // Weight kernels by dim 738*7d8d0e25Snbeams //------------------------------------------------------------------------------ 739*7d8d0e25Snbeams extern "C" __global__ void weight(const CeedInt nelem, 740*7d8d0e25Snbeams const CeedScalar *__restrict__ qweight1d, 741*7d8d0e25Snbeams CeedScalar *__restrict__ v) { 742*7d8d0e25Snbeams if (BASIS_DIM == 1) { 743*7d8d0e25Snbeams weight1d(nelem, qweight1d, v); 744*7d8d0e25Snbeams } else if (BASIS_DIM == 2) { 745*7d8d0e25Snbeams weight2d(nelem, qweight1d, v); 746*7d8d0e25Snbeams } else if (BASIS_DIM == 3) { 747*7d8d0e25Snbeams weight3d(nelem, qweight1d, v); 748*7d8d0e25Snbeams } 749*7d8d0e25Snbeams } 750*7d8d0e25Snbeams 751*7d8d0e25Snbeams ); 752*7d8d0e25Snbeams // *INDENT-ON* 753*7d8d0e25Snbeams 754*7d8d0e25Snbeams //------------------------------------------------------------------------------ 755*7d8d0e25Snbeams // Device initalization 756*7d8d0e25Snbeams //------------------------------------------------------------------------------ 757*7d8d0e25Snbeams int CeedHipInitInterp(CeedScalar *d_B, CeedInt P1d, CeedInt Q1d, 758*7d8d0e25Snbeams CeedScalar **c_B); 759*7d8d0e25Snbeams int CeedHipInitInterpGrad(CeedScalar *d_B, CeedScalar *d_G, 
CeedInt P1d, 760*7d8d0e25Snbeams CeedInt Q1d, CeedScalar **c_B_ptr, 761*7d8d0e25Snbeams CeedScalar **c_G_ptr); 762*7d8d0e25Snbeams 763*7d8d0e25Snbeams //------------------------------------------------------------------------------ 764*7d8d0e25Snbeams // Apply basis 765*7d8d0e25Snbeams //------------------------------------------------------------------------------ 766*7d8d0e25Snbeams int CeedBasisApplyTensor_Hip_shared(CeedBasis basis, const CeedInt nelem, 767*7d8d0e25Snbeams CeedTransposeMode tmode, 768*7d8d0e25Snbeams CeedEvalMode emode, CeedVector u, 769*7d8d0e25Snbeams CeedVector v) { 770*7d8d0e25Snbeams int ierr; 771*7d8d0e25Snbeams Ceed ceed; 772*7d8d0e25Snbeams ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr); 773*7d8d0e25Snbeams Ceed_Hip_shared *ceed_Hip; 774*7d8d0e25Snbeams CeedGetData(ceed, &ceed_Hip); CeedChk(ierr); 775*7d8d0e25Snbeams CeedBasis_Hip_shared *data; 776*7d8d0e25Snbeams CeedBasisGetData(basis, &data); CeedChk(ierr); 777*7d8d0e25Snbeams const CeedInt transpose = tmode == CEED_TRANSPOSE; 778*7d8d0e25Snbeams CeedInt dim, ncomp; 779*7d8d0e25Snbeams ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr); 780*7d8d0e25Snbeams ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr); 781*7d8d0e25Snbeams 782*7d8d0e25Snbeams // Read vectors 783*7d8d0e25Snbeams const CeedScalar *d_u; 784*7d8d0e25Snbeams CeedScalar *d_v; 785*7d8d0e25Snbeams if (emode != CEED_EVAL_WEIGHT) { 786*7d8d0e25Snbeams ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChk(ierr); 787*7d8d0e25Snbeams } 788*7d8d0e25Snbeams ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChk(ierr); 789*7d8d0e25Snbeams 790*7d8d0e25Snbeams // Clear v for transpose mode 791*7d8d0e25Snbeams if (tmode == CEED_TRANSPOSE) { 792*7d8d0e25Snbeams CeedInt length; 793*7d8d0e25Snbeams ierr = CeedVectorGetLength(v, &length); CeedChk(ierr); 794*7d8d0e25Snbeams ierr = hipMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChk(ierr); 795*7d8d0e25Snbeams } 796*7d8d0e25Snbeams 
797*7d8d0e25Snbeams // Apply basis operation 798*7d8d0e25Snbeams switch (emode) { 799*7d8d0e25Snbeams case CEED_EVAL_INTERP: { 800*7d8d0e25Snbeams CeedInt P1d, Q1d; 801*7d8d0e25Snbeams ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr); 802*7d8d0e25Snbeams ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr); 803*7d8d0e25Snbeams CeedInt thread1d = CeedIntMax(Q1d, P1d); 804*7d8d0e25Snbeams ierr = CeedHipInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B); 805*7d8d0e25Snbeams CeedChk(ierr); 806*7d8d0e25Snbeams void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B, 807*7d8d0e25Snbeams &d_u, &d_v 808*7d8d0e25Snbeams }; 809*7d8d0e25Snbeams if (dim == 1) { 810*7d8d0e25Snbeams CeedInt elemsPerBlock = 32*thread1d > 256? 256/thread1d : 32; 811*7d8d0e25Snbeams elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1; 812*7d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 813*7d8d0e25Snbeams ? 1 : 0 ); 814*7d8d0e25Snbeams CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar); 815*7d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, 1, 816*7d8d0e25Snbeams elemsPerBlock, sharedMem, 817*7d8d0e25Snbeams interpargs); CeedChk(ierr); 818*7d8d0e25Snbeams } else if (dim == 2) { 819*7d8d0e25Snbeams const CeedInt optElems[7] = {0,32,8,6,4,2,6}; 820*7d8d0e25Snbeams // elemsPerBlock must be at least 1 821*7d8d0e25Snbeams CeedInt elemsPerBlock = CeedIntMax(thread1d<7?optElems[thread1d]/ncomp:1, 1); 822*7d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 823*7d8d0e25Snbeams ? 
1 : 0 ); 824*7d8d0e25Snbeams CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar); 825*7d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d, 826*7d8d0e25Snbeams ncomp*elemsPerBlock, sharedMem, 827*7d8d0e25Snbeams interpargs); CeedChk(ierr); 828*7d8d0e25Snbeams } else if (dim == 3) { 829*7d8d0e25Snbeams CeedInt elemsPerBlock = 1; 830*7d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 831*7d8d0e25Snbeams ? 1 : 0 ); 832*7d8d0e25Snbeams CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar); 833*7d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d, 834*7d8d0e25Snbeams ncomp*elemsPerBlock, sharedMem, 835*7d8d0e25Snbeams interpargs); CeedChk(ierr); 836*7d8d0e25Snbeams } 837*7d8d0e25Snbeams } break; 838*7d8d0e25Snbeams case CEED_EVAL_GRAD: { 839*7d8d0e25Snbeams CeedInt P1d, Q1d; 840*7d8d0e25Snbeams ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr); 841*7d8d0e25Snbeams ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr); 842*7d8d0e25Snbeams CeedInt thread1d = CeedIntMax(Q1d, P1d); 843*7d8d0e25Snbeams ierr = CeedHipInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d, 844*7d8d0e25Snbeams Q1d, &data->c_B, &data->c_G); 845*7d8d0e25Snbeams CeedChk(ierr); 846*7d8d0e25Snbeams void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B, 847*7d8d0e25Snbeams &data->c_G, &d_u, &d_v 848*7d8d0e25Snbeams }; 849*7d8d0e25Snbeams if (dim == 1) { 850*7d8d0e25Snbeams CeedInt elemsPerBlock = 32*thread1d > 256? 256/thread1d : 32; 851*7d8d0e25Snbeams elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1; 852*7d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 853*7d8d0e25Snbeams ? 
1 : 0 ); 854*7d8d0e25Snbeams CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar); 855*7d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, 1, 856*7d8d0e25Snbeams elemsPerBlock, sharedMem, gradargs); 857*7d8d0e25Snbeams CeedChk(ierr); 858*7d8d0e25Snbeams } else if (dim == 2) { 859*7d8d0e25Snbeams const CeedInt optElems[7] = {0,32,8,6,4,2,6}; 860*7d8d0e25Snbeams // elemsPerBlock must be at least 1 861*7d8d0e25Snbeams CeedInt elemsPerBlock = CeedIntMax(thread1d<7?optElems[thread1d]/ncomp:1, 1); 862*7d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 863*7d8d0e25Snbeams ? 1 : 0 ); 864*7d8d0e25Snbeams CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar); 865*7d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d, 866*7d8d0e25Snbeams ncomp*elemsPerBlock, sharedMem, 867*7d8d0e25Snbeams gradargs); CeedChk(ierr); 868*7d8d0e25Snbeams } else if (dim == 3) { 869*7d8d0e25Snbeams CeedInt elemsPerBlock = 1; 870*7d8d0e25Snbeams CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 871*7d8d0e25Snbeams ? 
1 : 0 ); 872*7d8d0e25Snbeams CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar); 873*7d8d0e25Snbeams ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d, 874*7d8d0e25Snbeams ncomp*elemsPerBlock, sharedMem, 875*7d8d0e25Snbeams gradargs); CeedChk(ierr); 876*7d8d0e25Snbeams } 877*7d8d0e25Snbeams } break; 878*7d8d0e25Snbeams case CEED_EVAL_WEIGHT: { 879*7d8d0e25Snbeams CeedInt Q1d; 880*7d8d0e25Snbeams ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr); 881*7d8d0e25Snbeams void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v}; 882*7d8d0e25Snbeams if (dim == 1) { 883*7d8d0e25Snbeams const CeedInt optElems = 32/Q1d; 884*7d8d0e25Snbeams const CeedInt elemsPerBlock = optElems>0?optElems:1; 885*7d8d0e25Snbeams const CeedInt gridsize = nelem/elemsPerBlock + ( ( 886*7d8d0e25Snbeams nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 ); 887*7d8d0e25Snbeams ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, 888*7d8d0e25Snbeams elemsPerBlock, 1, weightargs); 889*7d8d0e25Snbeams CeedChk(ierr); 890*7d8d0e25Snbeams } else if (dim == 2) { 891*7d8d0e25Snbeams const CeedInt optElems = 32/(Q1d*Q1d); 892*7d8d0e25Snbeams const CeedInt elemsPerBlock = optElems>0?optElems:1; 893*7d8d0e25Snbeams const CeedInt gridsize = nelem/elemsPerBlock + ( ( 894*7d8d0e25Snbeams nelem/elemsPerBlock*elemsPerBlock<nelem)? 
1 : 0 ); 895*7d8d0e25Snbeams ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d, 896*7d8d0e25Snbeams elemsPerBlock, weightargs); 897*7d8d0e25Snbeams CeedChk(ierr); 898*7d8d0e25Snbeams } else if (dim == 3) { 899*7d8d0e25Snbeams const CeedInt gridsize = nelem; 900*7d8d0e25Snbeams ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d, Q1d, 901*7d8d0e25Snbeams weightargs); 902*7d8d0e25Snbeams CeedChk(ierr); 903*7d8d0e25Snbeams } 904*7d8d0e25Snbeams } break; 905*7d8d0e25Snbeams // LCOV_EXCL_START 906*7d8d0e25Snbeams // Evaluate the divergence to/from the quadrature points 907*7d8d0e25Snbeams case CEED_EVAL_DIV: 908*7d8d0e25Snbeams return CeedError(ceed, 1, "CEED_EVAL_DIV not supported"); 909*7d8d0e25Snbeams // Evaluate the curl to/from the quadrature points 910*7d8d0e25Snbeams case CEED_EVAL_CURL: 911*7d8d0e25Snbeams return CeedError(ceed, 1, "CEED_EVAL_CURL not supported"); 912*7d8d0e25Snbeams // Take no action, BasisApply should not have been called 913*7d8d0e25Snbeams case CEED_EVAL_NONE: 914*7d8d0e25Snbeams return CeedError(ceed, 1, 915*7d8d0e25Snbeams "CEED_EVAL_NONE does not make sense in this context"); 916*7d8d0e25Snbeams // LCOV_EXCL_STOP 917*7d8d0e25Snbeams } 918*7d8d0e25Snbeams 919*7d8d0e25Snbeams // Restore vectors 920*7d8d0e25Snbeams if (emode != CEED_EVAL_WEIGHT) { 921*7d8d0e25Snbeams ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChk(ierr); 922*7d8d0e25Snbeams } 923*7d8d0e25Snbeams ierr = CeedVectorRestoreArray(v, &d_v); CeedChk(ierr); 924*7d8d0e25Snbeams return 0; 925*7d8d0e25Snbeams } 926*7d8d0e25Snbeams 927*7d8d0e25Snbeams //------------------------------------------------------------------------------ 928*7d8d0e25Snbeams // Destroy basis 929*7d8d0e25Snbeams //------------------------------------------------------------------------------ 930*7d8d0e25Snbeams static int CeedBasisDestroy_Hip_shared(CeedBasis basis) { 931*7d8d0e25Snbeams int ierr; 932*7d8d0e25Snbeams Ceed ceed; 933*7d8d0e25Snbeams ierr = CeedBasisGetCeed(basis, 
&ceed); CeedChk(ierr); 934*7d8d0e25Snbeams 935*7d8d0e25Snbeams CeedBasis_Hip_shared *data; 936*7d8d0e25Snbeams ierr = CeedBasisGetData(basis, &data); CeedChk(ierr); 937*7d8d0e25Snbeams 938*7d8d0e25Snbeams CeedChk_Hip(ceed, hipModuleUnload(data->module)); 939*7d8d0e25Snbeams 940*7d8d0e25Snbeams ierr = hipFree(data->d_qweight1d); CeedChk_Hip(ceed, ierr); 941*7d8d0e25Snbeams ierr = hipFree(data->d_interp1d); CeedChk_Hip(ceed, ierr); 942*7d8d0e25Snbeams ierr = hipFree(data->d_grad1d); CeedChk_Hip(ceed, ierr); 943*7d8d0e25Snbeams ierr = hipFree(data->d_collograd1d); CeedChk_Hip(ceed, ierr); 944*7d8d0e25Snbeams 945*7d8d0e25Snbeams ierr = CeedFree(&data); CeedChk(ierr); 946*7d8d0e25Snbeams 947*7d8d0e25Snbeams return 0; 948*7d8d0e25Snbeams } 949*7d8d0e25Snbeams 950*7d8d0e25Snbeams //------------------------------------------------------------------------------ 951*7d8d0e25Snbeams // Create tensor basis 952*7d8d0e25Snbeams //------------------------------------------------------------------------------ 953*7d8d0e25Snbeams int CeedBasisCreateTensorH1_Hip_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d, 954*7d8d0e25Snbeams const CeedScalar *interp1d, 955*7d8d0e25Snbeams const CeedScalar *grad1d, 956*7d8d0e25Snbeams const CeedScalar *qref1d, 957*7d8d0e25Snbeams const CeedScalar *qweight1d, 958*7d8d0e25Snbeams CeedBasis basis) { 959*7d8d0e25Snbeams int ierr; 960*7d8d0e25Snbeams Ceed ceed; 961*7d8d0e25Snbeams ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr); 962*7d8d0e25Snbeams CeedBasis_Hip_shared *data; 963*7d8d0e25Snbeams ierr = CeedCalloc(1, &data); CeedChk(ierr); 964*7d8d0e25Snbeams 965*7d8d0e25Snbeams // Copy basis data to GPU 966*7d8d0e25Snbeams const CeedInt qBytes = Q1d * sizeof(CeedScalar); 967*7d8d0e25Snbeams ierr = hipMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Hip(ceed, ierr); 968*7d8d0e25Snbeams ierr = hipMemcpy(data->d_qweight1d, qweight1d, qBytes, 969*7d8d0e25Snbeams hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr); 970*7d8d0e25Snbeams 
971*7d8d0e25Snbeams const CeedInt iBytes = qBytes * P1d; 972*7d8d0e25Snbeams ierr = hipMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Hip(ceed, ierr); 973*7d8d0e25Snbeams ierr = hipMemcpy(data->d_interp1d, interp1d, iBytes, 974*7d8d0e25Snbeams hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr); 975*7d8d0e25Snbeams 976*7d8d0e25Snbeams ierr = hipMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Hip(ceed, ierr); 977*7d8d0e25Snbeams ierr = hipMemcpy(data->d_grad1d, grad1d, iBytes, 978*7d8d0e25Snbeams hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr); 979*7d8d0e25Snbeams 980*7d8d0e25Snbeams // Compute collocated gradient and copy to GPU 981*7d8d0e25Snbeams data->d_collograd1d = NULL; 982*7d8d0e25Snbeams if (dim == 3 && Q1d >= P1d) { 983*7d8d0e25Snbeams CeedScalar *collograd1d; 984*7d8d0e25Snbeams ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChk(ierr); 985*7d8d0e25Snbeams ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChk(ierr); 986*7d8d0e25Snbeams ierr = hipMalloc((void **)&data->d_collograd1d, qBytes * Q1d); 987*7d8d0e25Snbeams CeedChk_Hip(ceed, ierr); 988*7d8d0e25Snbeams ierr = hipMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d, 989*7d8d0e25Snbeams hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr); 990*7d8d0e25Snbeams ierr = CeedFree(&collograd1d); CeedChk(ierr); 991*7d8d0e25Snbeams } 992*7d8d0e25Snbeams 993*7d8d0e25Snbeams // Compile basis kernels 994*7d8d0e25Snbeams CeedInt ncomp; 995*7d8d0e25Snbeams ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr); 996*7d8d0e25Snbeams ierr = CeedCompileHip(ceed, kernelsShared, &data->module, 8, 997*7d8d0e25Snbeams "Q1D", Q1d, 998*7d8d0e25Snbeams "P1D", P1d, 999*7d8d0e25Snbeams "T1D", CeedIntMax(Q1d, P1d), 1000*7d8d0e25Snbeams "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ? 
1001*7d8d0e25Snbeams Q1d : P1d, dim), 1002*7d8d0e25Snbeams "BASIS_DIM", dim, 1003*7d8d0e25Snbeams "BASIS_NCOMP", ncomp, 1004*7d8d0e25Snbeams "BASIS_ELEMSIZE", CeedIntPow(P1d, dim), 1005*7d8d0e25Snbeams "BASIS_NQPT", CeedIntPow(Q1d, dim) 1006*7d8d0e25Snbeams ); CeedChk(ierr); 1007*7d8d0e25Snbeams ierr = CeedGetKernelHip(ceed, data->module, "interp", &data->interp); 1008*7d8d0e25Snbeams CeedChk(ierr); 1009*7d8d0e25Snbeams ierr = CeedGetKernelHip(ceed, data->module, "grad", &data->grad); 1010*7d8d0e25Snbeams CeedChk(ierr); 1011*7d8d0e25Snbeams ierr = CeedGetKernelHip(ceed, data->module, "weight", &data->weight); 1012*7d8d0e25Snbeams CeedChk(ierr); 1013*7d8d0e25Snbeams 1014*7d8d0e25Snbeams ierr = CeedBasisSetData(basis, data); CeedChk(ierr); 1015*7d8d0e25Snbeams 1016*7d8d0e25Snbeams // Register backend functions 1017*7d8d0e25Snbeams ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply", 1018*7d8d0e25Snbeams CeedBasisApplyTensor_Hip_shared); 1019*7d8d0e25Snbeams CeedChk(ierr); 1020*7d8d0e25Snbeams ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", 1021*7d8d0e25Snbeams CeedBasisDestroy_Hip_shared); CeedChk(ierr); 1022*7d8d0e25Snbeams return 0; 1023*7d8d0e25Snbeams } 1024*7d8d0e25Snbeams //------------------------------------------------------------------------------ 1025