// Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
// Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
// All Rights reserved. See files LICENSE and NOTICE for details.
//
// This file is part of CEED, a collection of benchmarks, miniapps, software
// libraries and APIs for efficient high-order finite element and spectral
// element discretizations for exascale applications. For more information and
// source code availability see http://github.com/ceed.
//
// The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
// a collaborative effort of two U.S. Department of Energy organizations (Office
// of Science and the National Nuclear Security Administration) responsible for
// the planning and preparation of a capable exascale ecosystem, including
// software, applications, hardware, advanced system engineering and early
// testbed platforms, in support of the nation's exascale computing imperative.

#include <ceed-backend.h>
#include <ceed.h>
#include "ceed-cuda-shared.h"
#include "../cuda/ceed-cuda.h"

//*********************
// shared mem kernels
//
// Device source for the tensor-product basis kernels, kept as a string.
// The text below uses P1D, Q1D, BASIS_NCOMP and BASIS_DIM without defining
// them. NOTE(review): presumably they are supplied as macros when this string
// is compiled (P1D = nodes per dimension, Q1D = quadrature points per
// dimension) -- confirm against the code that compiles kernelsShared.
// The code sits inside a macro argument, so it must not contain preprocessor
// directives.
//
// Data layouts established by the index expressions below:
//   dofs  : node index fastest, then component, then element
//   quads : point index fastest, then element, then component, then dim
//   B / G : entry (node i, point q) is stored at [i + q*P1D]
static const char *kernelsShared = QUOTE(

// r_V[i] += r_U[i] for the Q1D per-thread values.
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  for (int i = 0; i < Q1D; i++)
    r_V[i] += r_U[i];
}

//////////
//  1D  //
//////////
// In the 1D path threadIdx.x selects a point and threadIdx.z an element of the
// block; element tidz owns the Q1D entries starting at slice[tidz*Q1D].
// NOTE(review): none of the 1D helpers contains a barrier, yet every thread of
// an element rewrites the whole shared row for each component while sibling
// threads may still be reading the previous one. This looks like a
// shared-memory race unless those threads execute in lockstep -- confirm.

// Copies the P1D nodal values of (elem, comp) into the element's shared row
// and zero-fills entries P1D..Q1D-1. tidx, tidy and nelem are unused.
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int tidz, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *slice) {
  for (int i = 0; i < P1D; i++)
    slice[i+tidz*Q1D] = d_U[i + comp*P1D + elem*BASIS_NCOMP*P1D];
  for (int i = P1D; i < Q1D; i++)
    slice[i+tidz*Q1D] = 0.0;
}

// Stores r_V as nodal value tidx of (elem, comp); threads with tidx >= P1D
// store nothing. tidy and nelem are unused.
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx<P1D) {
    d_V[tidx + comp*P1D + elem*BASIS_NCOMP*P1D] = r_V;
  }
}

// Copies the Q1D quadrature values of (elem, comp, dim) into the element's
// shared row. tidx and tidy are unused.
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int tidz, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *slice) {
  for (int i = 0; i < Q1D; i++)
    slice[i+tidz*Q1D] = d_U[i + elem*Q1D + comp*Q1D*nelem +
                            dim*BASIS_NCOMP*nelem*Q1D];
}

// Stores r_V as quadrature value tidx of (elem, comp, dim). tidy is unused.
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  d_V[tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D] = r_V;
}

// V = sum over nodes i of B[i + tidx*P1D] * slice row entry i.
// The input is taken from the shared row; U and tidy are unused.
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  V = 0.0;
  for (int i = 0; i < P1D; ++i) {
    V += B[i + tidx*P1D] * slice[i+tidz*Q1D];//contract x direction
  }
}

// V = sum over points i of B[tidx + i*P1D] * slice row entry i, for node tidx.
// The input is taken from the shared row; U and tidy are unused.
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  V = 0.0;
  // Threads with tidx >= P1D have no node to produce (writeDofs1d drops their
  // result); skipping them keeps B from being indexed past the P1D*Q1D
  // entries used by the other contractions. Same guard as the 2D and 3D
  // transpose contractions.
  if (tidx<P1D) {
    for (int i = 0; i < Q1D; ++i) {
      V += B[tidx + i*P1D] * slice[i+tidz*Q1D];//contract x direction
    }
  }
}

// 1D interpolation. For every element handled by this thread (elements are
// strided by gridDim.x*blockDim.z) and every component:
//   transpose == 0 : dofs  -> quads using c_B
//   transpose != 0 : quads -> dofs  using c_B transposed
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t; // only fills the unused U slot of the 1D contractions

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;

  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp=0; comp<BASIS_NCOMP; comp++) {
      if(!transpose) {
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      } else {
        readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}

// 1D gradient: same structure as interp1d but contracting with c_G.
// c_B is not used in 1D.
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  CeedScalar r_U; // only fills the unused U slot of the 1D contractions
  CeedScalar r_V;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  int dim;

  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp=0; comp<BASIS_NCOMP; comp++) {
      if(!transpose) {
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        dim = 0;
        writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      } else {
        dim = 0;
        readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}

//////////
//  2D  //
//////////
// In the 2D path (threadIdx.x, threadIdx.y) selects a point of a Q1D x Q1D
// tile and threadIdx.z selects an (element, component) pair:
// tidz/BASIS_NCOMP is the element within the block, tidz%BASIS_NCOMP the
// component. Each tidz owns the Q1D*Q1D entries starting at
// slice[tidz*Q1D*Q1D].
// NOTE(review): the contraction helpers call __syncthreads() from inside the
// element loop, whose trip count depends on threadIdx.z. When nelem is not a
// multiple of the elements per block, some threads leave the loop while
// others still reach the barrier -- confirm this is acceptable on the
// targeted hardware.

// U = nodal value (tidx, tidy) of (elem, comp), or 0 for threads outside the
// P1D x P1D node range. nelem is unused.
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar &U) {
  U = (tidx<P1D && tidy<P1D) ?
      d_U[tidx + tidy*P1D + comp*P1D*P1D + elem*BASIS_NCOMP*P1D*P1D] : 0.0;
}

// Stores r_V as nodal value (tidx, tidy) of (elem, comp); threads outside the
// P1D x P1D node range store nothing. nelem is unused.
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx<P1D && tidy<P1D) {
    d_V[tidx + tidy*P1D + comp*P1D*P1D + elem*BASIS_NCOMP*P1D*P1D] = r_V;
  }
}

// U = quadrature value (tidx, tidy) of (elem, comp, dim).
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar &U) {
  U = d_U[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
          dim*BASIS_NCOMP*nelem*Q1D*Q1D];
}

// Stores r_V as quadrature value (tidx, tidy) of (elem, comp, dim).
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  d_V[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
      dim*BASIS_NCOMP*nelem*Q1D*Q1D] = r_V;
}

// Publishes U in the shared tile, then V = sum_i B[i + tidx*P1D] * tile(i, tidy).
// The trailing barrier lets the next helper overwrite the tile safely.
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  for (int i = 0; i < P1D; ++i) {
    V += B[i + tidx*P1D] * slice[i + tidy*Q1D + tidz*Q1D*Q1D];//contract x direction
  }
  __syncthreads();
}

// Publishes U in the shared tile, then V = sum_i B[i + tidy*P1D] * tile(tidx, i).
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  for (int i = 0; i < P1D; ++i) {
    V += B[i + tidy*P1D] * slice[tidx + i*Q1D + tidz*Q1D*Q1D];//contract y direction
  }
  __syncthreads();
}

// Publishes U in the shared tile, then for tidy < P1D
// V = sum_i B[tidy + i*P1D] * tile(tidx, i); other threads get V = 0.
// The barriers are outside the guard so every thread reaches them.
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidy<P1D) {
    for (int i = 0; i < Q1D; ++i) {
      V += B[tidy + i*P1D] * slice[tidx + i*Q1D + tidz*Q1D*Q1D];//contract y direction
    }
  }
  __syncthreads();
}

// Publishes U in the shared tile, then for tidx < P1D
// V = sum_i B[tidx + i*P1D] * tile(i, tidy); other threads get V = 0.
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidx<P1D) {
    for (int i = 0; i < Q1D; ++i) {
      V += B[tidx + i*P1D] * slice[i + tidy*Q1D + tidz*Q1D*Q1D];//contract x direction
    }
  }
  __syncthreads();
}

// 2D interpolation for the (element, component) pair selected by threadIdx.z.
//   transpose == 0 : dofs  -> quads (contract x then y with c_B)
//   transpose != 0 : quads -> dofs  (contract y then x with c_B transposed)
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    r_V = 0.0;
    r_t = 0.0;
    if(!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
    } else {
      readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}

// 2D gradient for the (element, component) pair selected by threadIdx.z.
//   transpose == 0 : writes dim 0 (c_G in x, c_B in y) and dim 1 (c_B in x,
//                    c_G in y) of the quadrature output
//   transpose != 0 : reads both dims, applies the transposed contractions and
//                    writes the sum to the dofs
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if(!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 1;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      dim = 0;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 1;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      r_V+=r_U;
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}

//////////
//  3D  //
//////////
// In the 3D path (threadIdx.x, threadIdx.y) selects a column of the element
// and threadIdx.z an (element, component) pair, as in 2D. Each thread keeps
// the values of its column along z in a Q1D-long local array. The x and y
// contractions exchange one z-plane at a time through the shared tile; the
// z contraction works on the thread's own array only.

// Loads the P1D nodal values of column (tidx, tidy) of (elem, comp) into r_U
// (zeros for threads outside the node range) and zero-fills r_U[P1D..Q1D-1].
// nelem is unused.
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *r_U) {
  for (int i = 0; i < P1D; i++)
    r_U[i] = (tidx<P1D && tidy<P1D) ?
             d_U[tidx + tidy*P1D + i*P1D*P1D + comp*P1D*P1D*P1D +
                 elem*BASIS_NCOMP*P1D*P1D*P1D] : 0.0;
  for (int i = P1D; i < Q1D; i++)
    r_U[i] = 0.0;
}

// Loads the Q1D quadrature values of column (tidx, tidy) of (elem, comp, dim).
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *r_U) {
  for (int i = 0; i < Q1D; i++)
    r_U[i] = d_U[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D +
                 comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D];
}

// Stores r_V[0..P1D-1] as the nodal values of column (tidx, tidy) of
// (elem, comp); threads outside the node range store nothing. nelem is unused.
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V,
                                   CeedScalar *d_V) {
  if (tidx<P1D && tidy<P1D) {
    for (int i = 0; i < P1D; i++)
      d_V[tidx + tidy*P1D + i*P1D*P1D + comp*P1D*P1D*P1D +
          elem*BASIS_NCOMP*P1D*P1D*P1D] = r_V[i];
  }
}

// Stores r_V[0..Q1D-1] as the quadrature values of column (tidx, tidy) of
// (elem, comp, dim).
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar *r_V, CeedScalar *d_V) {
  for (int i = 0; i < Q1D; i++)
    d_V[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
        dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] = r_V[i];
}

// For each of the first P1D planes k: publishes U[k] in the shared tile and
// sets V[k] = sum_i B[i + tidx*P1D] * tile(i, tidy). V[P1D..] is not touched.
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    for (int i = 0; i < P1D; ++i) {
      V[k] += B[i + tidx*P1D] * slice[i + tidy*Q1D +
                                      tidz*Q1D*Q1D];//contract x direction
    }
    __syncthreads();
  }
}

// For each of the first P1D planes k: publishes U[k] in the shared tile and
// sets V[k] = sum_i B[i + tidy*P1D] * tile(tidx, i). V[P1D..] is not touched.
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    for (int i = 0; i < P1D; ++i) {
      V[k] += B[i + tidy*P1D] * slice[tidx + i*Q1D +
                                      tidz*Q1D*Q1D];//contract y direction
    }
    __syncthreads();
  }
}

// V[k] = sum_{i<P1D} B[i + k*P1D] * U[i] for k < Q1D. Works on the thread's
// own arrays only; slice and the thread indices are unused.
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < Q1D; ++k) {
    V[k] = 0.0;
    for (int i = 0; i < P1D; ++i) {
      V[k] += B[i + k*P1D] * U[i];//contract z direction
    }
  }
}

// V[k] = sum_{i<Q1D} B[k + i*P1D] * U[i] for k < P1D and V[k] = 0 for
// P1D <= k < Q1D. slice and the thread indices are unused.
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < Q1D; ++k) {
    V[k] = 0.0;
    if (k<P1D) {
      for (int i = 0; i < Q1D; ++i) {
        V[k] += B[k + i*P1D] * U[i];//contract z direction
      }
    }
  }
}

// For each of the first P1D planes k: publishes U[k] in the shared tile and,
// for tidy < P1D, sets V[k] = sum_i B[tidy + i*P1D] * tile(tidx, i);
// other threads get V[k] = 0. The barriers stay outside the guard.
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidy<P1D) {
      for (int i = 0; i < Q1D; ++i) {
        V[k] += B[tidy + i*P1D] * slice[tidx + i*Q1D +
                                        tidz*Q1D*Q1D];//contract y direction
      }
    }
    __syncthreads();
  }
}

// For each of the first P1D planes k: publishes U[k] in the shared tile and,
// for tidx < P1D, sets V[k] = sum_i B[tidx + i*P1D] * tile(i, tidy);
// other threads get V[k] = 0. The barriers stay outside the guard.
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx<P1D) {
      for (int i = 0; i < Q1D; ++i) {
        V[k] += B[tidx + i*P1D] * slice[i + tidy*Q1D +
                                        tidz*Q1D*Q1D];//contract x direction
      }
    }
    __syncthreads();
  }
}

// 3D interpolation for the (element, component) pair selected by threadIdx.z.
//   transpose == 0 : dofs  -> quads (contract x, y, z with c_B)
//   transpose != 0 : quads -> dofs  (contract z, y, x with c_B transposed)
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V[Q1D];
  CeedScalar r_t[Q1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    for (int i = 0; i < Q1D; ++i) {
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if(!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
    } else {
      readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
    }
  }
}

// 3D gradient for the (element, component) pair selected by threadIdx.z.
//   transpose == 0 : writes dims 0, 1, 2 of the quadrature output, using c_G
//                    in the x, y, z contraction respectively and c_B in the
//                    other two
//   transpose != 0 : reads the three dims, applies the transposed
//                    contractions and writes their sum to the dofs
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  //use P1D for one of these
  CeedScalar r_U[Q1D];
  CeedScalar r_V[Q1D];
  CeedScalar r_t[Q1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if(!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_G, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 1;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 2;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      dim = 0;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      dim = 1;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_G, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      dim = 2;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}

/////////////
// Kernels //
/////////////

// Interpolation entry point; dispatches on BASIS_DIM.
// Uses dynamic shared memory: the launch must provide Q1D entries (1D) or
// Q1D*Q1D entries (2D, 3D) of CeedScalar per threadIdx.z value.
extern "C" __global__ void interp(const CeedInt nelem, const int transpose,
                                  const CeedScalar *c_B,
                                  const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  // Typed as CeedScalar to match the helpers that receive it.
  extern __shared__ CeedScalar slice[];
  if (BASIS_DIM==1) {
    interp1d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM==2) {
    interp2d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM==3) {
    interp3d(nelem, transpose, c_B, d_U, d_V, slice);
  }
}

// Gradient entry point; dispatches on BASIS_DIM.
// Same dynamic shared memory requirement as interp.
extern "C" __global__ void grad(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *c_G,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V) {
  // Typed as CeedScalar to match the helpers that receive it.
  extern __shared__ CeedScalar slice[];
  if (BASIS_DIM==1) {
    grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM==2) {
    grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM==3) {
    grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  }
}

/////////////
// Weights //
/////////////

// w[elem*Q1D + q] = qweight1d[q] with q = threadIdx.x; threadIdx.y selects
// the element within the block and elements are strided by
// gridDim.x*blockDim.y.
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int tid = threadIdx.x;
  const CeedScalar weight = qweight1d[tid];
  for (CeedInt elem = blockIdx.x*blockDim.y + threadIdx.y; elem < nelem;
       elem += gridDim.x*blockDim.y) {
    const int ind = elem*Q1D + tid;
    w[ind] = weight;
  }
}

// Writes the product qweight1d[i]*qweight1d[j] for point (i, j) =
// (threadIdx.x, threadIdx.y); threadIdx.z selects the element within the
// block and elements are strided by gridDim.x*blockDim.z.
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int i = threadIdx.x;
  const int j = threadIdx.y;
  const CeedScalar weight = qweight1d[i]*qweight1d[j];
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    const int ind = elem*Q1D*Q1D + i + j*Q1D;
    w[ind] = weight;
  }
}

// Writes the product of the three 1D weights for point (i, j, k) =
// (threadIdx.x, threadIdx.y, threadIdx.z); each block handles one element at
// a time and elements are strided by gridDim.x.
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int i = threadIdx.x;
  const int j = threadIdx.y;
  const int k = threadIdx.z;
  const CeedScalar weight = qweight1d[i]*qweight1d[j]*qweight1d[k];
  for (int e = blockIdx.x; e < nelem; e += gridDim.x) {
    const int ind = e*Q1D*Q1D*Q1D + i + j*Q1D + k*Q1D*Q1D;
    w[ind] = weight;
  }
}

// Quadrature weight entry point; dispatches on BASIS_DIM.
extern "C" __global__ void weight(const CeedInt nelem,
                                  const CeedScalar *__restrict__ qweight1d,
                                  CeedScalar *__restrict__ v) {
  if (BASIS_DIM==1) {
    weight1d(nelem, qweight1d, v);
  } else if (BASIS_DIM==2) {
    weight2d(nelem, qweight1d, v);
  } else if (BASIS_DIM==3) {
    weight3d(nelem, qweight1d, v);
  }
}

);

// Declared here, defined elsewhere.
int CeedCudaInitInterp(CeedScalar *d_B, CeedInt P1d, CeedInt Q1d,
                       CeedScalar **c_B);
int CeedCudaInitInterpGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P1d,
                           CeedInt Q1d, CeedScalar **c_B_ptr, CeedScalar **c_G_ptr);

int CeedBasisApplyTensor_Cuda_shared(CeedBasis basis, const CeedInt nelem,
                                     CeedTransposeMode tmode,
                                     CeedEvalMode emode, CeedVector u, CeedVector v) {
  int ierr;
  Ceed ceed;
  ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
Ceed_Cuda_shared *ceed_Cuda; 618c532df63SYohann CeedGetData(ceed, (void *) &ceed_Cuda); CeedChk(ierr); 619c532df63SYohann CeedBasis_Cuda_shared *data; 620c532df63SYohann CeedBasisGetData(basis, (void *)&data); CeedChk(ierr); 621c532df63SYohann const CeedInt transpose = tmode == CEED_TRANSPOSE; 6224247ecf3SYohann Dudouit CeedInt dim, ncomp; 623074be161SYohann Dudouit ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr); 6244247ecf3SYohann Dudouit ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr); 625c532df63SYohann 626c532df63SYohann const CeedScalar *d_u; 627c532df63SYohann CeedScalar *d_v; 628c532df63SYohann if(emode!=CEED_EVAL_WEIGHT) { 629c532df63SYohann ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChk(ierr); 630c532df63SYohann } 631c532df63SYohann ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChk(ierr); 632c532df63SYohann 633c532df63SYohann if (tmode == CEED_TRANSPOSE) { 634c532df63SYohann CeedInt length; 635c532df63SYohann ierr = CeedVectorGetLength(v, &length); CeedChk(ierr); 636c532df63SYohann ierr = cudaMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChk(ierr); 637c532df63SYohann } 638c532df63SYohann if (emode == CEED_EVAL_INTERP) { 639c532df63SYohann CeedInt P1d, Q1d; 640c532df63SYohann ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr); 641c532df63SYohann ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr); 642c532df63SYohann ierr = CeedCudaInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B); 643c532df63SYohann CeedChk(ierr); 644c532df63SYohann void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B, &d_u, &d_v}; 645*4d537eeaSYohann if (dim==1) { 646d94769d2SYohann Dudouit CeedInt elemsPerBlock = 32; 647*4d537eeaSYohann CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 648*4d537eeaSYohann ? 
1 : 0 ); 649d94769d2SYohann Dudouit CeedInt sharedMem = elemsPerBlock*Q1d*sizeof(CeedScalar); 650*4d537eeaSYohann ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, 1, 651*4d537eeaSYohann elemsPerBlock, sharedMem, 652c532df63SYohann interpargs); 653c532df63SYohann CeedChk(ierr); 654074be161SYohann Dudouit } else if (dim==2) { 6554247ecf3SYohann Dudouit const CeedInt optElems[7] = {0,32,8,6,4,2,8}; 6564247ecf3SYohann Dudouit CeedInt elemsPerBlock = Q1d < 7 ? optElems[Q1d]/ncomp : 1; 657*4d537eeaSYohann CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 658*4d537eeaSYohann ? 1 : 0 ); 6594247ecf3SYohann Dudouit CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar); 660*4d537eeaSYohann ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, Q1d, 661*4d537eeaSYohann ncomp*elemsPerBlock, sharedMem, 662074be161SYohann Dudouit interpargs); 663074be161SYohann Dudouit CeedChk(ierr); 664074be161SYohann Dudouit } else if (dim==3) { 6653f63d318SYohann Dudouit CeedInt elemsPerBlock = 1; 666*4d537eeaSYohann CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 667*4d537eeaSYohann ? 
1 : 0 ); 668698ebc35SYohann Dudouit CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar); 669*4d537eeaSYohann ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, Q1d, 670*4d537eeaSYohann ncomp*elemsPerBlock, sharedMem, 671074be161SYohann Dudouit interpargs); 672074be161SYohann Dudouit CeedChk(ierr); 673074be161SYohann Dudouit } 674c532df63SYohann } else if (emode == CEED_EVAL_GRAD) { 675c532df63SYohann CeedInt P1d, Q1d; 676c532df63SYohann ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr); 677c532df63SYohann ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr); 678c532df63SYohann ierr = CeedCudaInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d, 679c532df63SYohann Q1d, &data->c_B, &data->c_G); 680c532df63SYohann CeedChk(ierr); 681c532df63SYohann void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B, &data->c_G, &d_u, &d_v}; 682*4d537eeaSYohann if (dim==1) { 683d94769d2SYohann Dudouit CeedInt elemsPerBlock = 32; 684*4d537eeaSYohann CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 685*4d537eeaSYohann ? 1 : 0 ); 686d94769d2SYohann Dudouit CeedInt sharedMem = elemsPerBlock*Q1d*sizeof(CeedScalar); 687*4d537eeaSYohann ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, 1, elemsPerBlock, 688*4d537eeaSYohann sharedMem, 689c532df63SYohann gradargs); 690c532df63SYohann CeedChk(ierr); 691074be161SYohann Dudouit } else if (dim==2) { 6924247ecf3SYohann Dudouit const CeedInt optElems[7] = {0,32,8,6,4,2,8}; 6934247ecf3SYohann Dudouit CeedInt elemsPerBlock = Q1d < 7 ? optElems[Q1d]/ncomp : 1; 694*4d537eeaSYohann CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 695*4d537eeaSYohann ? 
1 : 0 ); 6964247ecf3SYohann Dudouit CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar); 697*4d537eeaSYohann ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, Q1d, 698*4d537eeaSYohann ncomp*elemsPerBlock, sharedMem, 699074be161SYohann Dudouit gradargs); 700074be161SYohann Dudouit CeedChk(ierr); 701074be161SYohann Dudouit } else if (dim==3) { 7023f63d318SYohann Dudouit CeedInt elemsPerBlock = 1; 703*4d537eeaSYohann CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 704*4d537eeaSYohann ? 1 : 0 ); 705698ebc35SYohann Dudouit CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar); 706*4d537eeaSYohann ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, Q1d, 707*4d537eeaSYohann ncomp*elemsPerBlock, sharedMem, 708074be161SYohann Dudouit gradargs); 709074be161SYohann Dudouit CeedChk(ierr); 710074be161SYohann Dudouit } 711c532df63SYohann } else if (emode == CEED_EVAL_WEIGHT) { 712074be161SYohann Dudouit CeedInt Q1d; 713074be161SYohann Dudouit ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr); 714c532df63SYohann void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v}; 715074be161SYohann Dudouit if(dim==1) { 716074be161SYohann Dudouit const CeedInt elemsPerBlock = 32/Q1d; 717*4d537eeaSYohann const CeedInt gridsize = nelem/elemsPerBlock + ( ( 718*4d537eeaSYohann nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 ); 719*4d537eeaSYohann ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, elemsPerBlock, 1, 720*4d537eeaSYohann weightargs); 7211226057fSYohann Dudouit CeedChk(ierr); 722074be161SYohann Dudouit } else if(dim==2) { 723717ff8a3SYohann Dudouit const CeedInt optElems = 32/(Q1d*Q1d); 724717ff8a3SYohann Dudouit const CeedInt elemsPerBlock = optElems>0?optElems:1; 725*4d537eeaSYohann const CeedInt gridsize = nelem/elemsPerBlock + ( ( 726*4d537eeaSYohann nelem/elemsPerBlock*elemsPerBlock<nelem)? 
1 : 0 ); 727*4d537eeaSYohann ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d, 728*4d537eeaSYohann elemsPerBlock, weightargs); 7291226057fSYohann Dudouit CeedChk(ierr); 730074be161SYohann Dudouit } else if(dim==3) { 731074be161SYohann Dudouit const CeedInt gridsize = nelem; 732*4d537eeaSYohann ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d, Q1d, 733*4d537eeaSYohann weightargs); 7341226057fSYohann Dudouit CeedChk(ierr); 735074be161SYohann Dudouit } 736c532df63SYohann } 737c532df63SYohann 738c532df63SYohann if(emode!=CEED_EVAL_WEIGHT) { 739c532df63SYohann ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChk(ierr); 740c532df63SYohann } 741c532df63SYohann ierr = CeedVectorRestoreArray(v, &d_v); CeedChk(ierr); 742c532df63SYohann 743c532df63SYohann return 0; 744c532df63SYohann } 745c532df63SYohann 746c532df63SYohann static int CeedBasisDestroy_Cuda_shared(CeedBasis basis) { 747c532df63SYohann int ierr; 748c532df63SYohann Ceed ceed; 749c532df63SYohann ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr); 750c532df63SYohann 751c532df63SYohann CeedBasis_Cuda_shared *data; 752c532df63SYohann ierr = CeedBasisGetData(basis, (void *) &data); CeedChk(ierr); 753c532df63SYohann 754c532df63SYohann CeedChk_Cu(ceed, cuModuleUnload(data->module)); 755c532df63SYohann 756c532df63SYohann ierr = cudaFree(data->d_qweight1d); CeedChk_Cu(ceed, ierr); 757c532df63SYohann ierr = cudaFree(data->d_interp1d); CeedChk_Cu(ceed, ierr); 758c532df63SYohann ierr = cudaFree(data->d_grad1d); CeedChk_Cu(ceed, ierr); 759c532df63SYohann 760c532df63SYohann ierr = CeedFree(&data); CeedChk(ierr); 761c532df63SYohann 762c532df63SYohann return 0; 763c532df63SYohann } 764c532df63SYohann 765c532df63SYohann int CeedBasisCreateTensorH1_Cuda_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d, 766c532df63SYohann const CeedScalar *interp1d, 767c532df63SYohann const CeedScalar *grad1d, 768c532df63SYohann const CeedScalar *qref1d, 769c532df63SYohann const CeedScalar *qweight1d, 
770c532df63SYohann CeedBasis basis) { 771c532df63SYohann int ierr; 772c532df63SYohann Ceed ceed; 773c532df63SYohann ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr); 774*4d537eeaSYohann if (Q1d<P1d) { 7751226057fSYohann Dudouit return CeedError(ceed, 1, "Backend does not implement underintegrated basis."); 7761226057fSYohann Dudouit } 777c532df63SYohann CeedBasis_Cuda_shared *data; 778c532df63SYohann ierr = CeedCalloc(1, &data); CeedChk(ierr); 779c532df63SYohann 780c532df63SYohann const CeedInt qBytes = Q1d * sizeof(CeedScalar); 781c532df63SYohann ierr = cudaMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Cu(ceed, ierr); 782c532df63SYohann ierr = cudaMemcpy(data->d_qweight1d, qweight1d, qBytes, 783c532df63SYohann cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr); 784c532df63SYohann 785c532df63SYohann const CeedInt iBytes = qBytes * P1d; 786c532df63SYohann ierr = cudaMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Cu(ceed, ierr); 787c532df63SYohann ierr = cudaMemcpy(data->d_interp1d, interp1d, iBytes, 788c532df63SYohann cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr); 789c532df63SYohann 790c532df63SYohann ierr = cudaMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Cu(ceed, ierr); 791c532df63SYohann ierr = cudaMemcpy(data->d_grad1d, grad1d, iBytes, 792c532df63SYohann cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr); 793c532df63SYohann 794c532df63SYohann CeedInt ncomp; 795c532df63SYohann ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr); 7964a6d4bbdSYohann Dudouit ierr = CeedCompileCuda(ceed, kernelsShared, &data->module, 7, 797c532df63SYohann "Q1D", Q1d, 798c532df63SYohann "P1D", P1d, 799c532df63SYohann "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ? 
800c532df63SYohann Q1d : P1d, dim), 801c532df63SYohann "BASIS_DIM", dim, 802c532df63SYohann "BASIS_NCOMP", ncomp, 803c532df63SYohann "BASIS_ELEMSIZE", CeedIntPow(P1d, dim), 804c532df63SYohann "BASIS_NQPT", CeedIntPow(Q1d, dim) 805c532df63SYohann ); CeedChk(ierr); 8064a6d4bbdSYohann Dudouit ierr = CeedGetKernelCuda(ceed, data->module, "interp", &data->interp); 807c532df63SYohann CeedChk(ierr); 8084a6d4bbdSYohann Dudouit ierr = CeedGetKernelCuda(ceed, data->module, "grad", &data->grad); 809c532df63SYohann CeedChk(ierr); 8104a6d4bbdSYohann Dudouit ierr = CeedGetKernelCuda(ceed, data->module, "weight", &data->weight); 811c532df63SYohann CeedChk(ierr); 812c532df63SYohann 813c532df63SYohann ierr = CeedBasisSetData(basis, (void *)&data); 814c532df63SYohann CeedChk(ierr); 815c532df63SYohann ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply", 816c532df63SYohann CeedBasisApplyTensor_Cuda_shared); 817c532df63SYohann CeedChk(ierr); 818c532df63SYohann ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", 819c532df63SYohann CeedBasisDestroy_Cuda_shared); 820c532df63SYohann CeedChk(ierr); 821c532df63SYohann return 0; 822c532df63SYohann } 823