1c532df63SYohann // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC. 2c532df63SYohann // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707. 3c532df63SYohann // All Rights reserved. See files LICENSE and NOTICE for details. 4c532df63SYohann // 5c532df63SYohann // This file is part of CEED, a collection of benchmarks, miniapps, software 6c532df63SYohann // libraries and APIs for efficient high-order finite element and spectral 7c532df63SYohann // element discretizations for exascale applications. For more information and 8c532df63SYohann // source code availability see http://github.com/ceed. 9c532df63SYohann // 10c532df63SYohann // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC, 11c532df63SYohann // a collaborative effort of two U.S. Department of Energy organizations (Office 12c532df63SYohann // of Science and the National Nuclear Security Administration) responsible for 13c532df63SYohann // the planning and preparation of a capable exascale ecosystem, including 14c532df63SYohann // software, applications, hardware, advanced system engineering and early 15c532df63SYohann // testbed platforms, in support of the nation's exascale computing imperative. 
#include <ceed-backend.h>
#include <ceed.h>
#include "ceed-cuda-shared.h"
#include "../cuda/ceed-cuda.h"

//*********************
// shared mem kernels
//
// Device source for the shared-memory tensor-product basis backend. The code
// below is turned into a string by the QUOTE macro (declared outside this
// file; presumably `#define QUOTE(...) #__VA_ARGS__` -- TODO confirm) and is
// compiled at run time, so it relies on BASIS_DIM, BASIS_NCOMP, P1D and Q1D
// being supplied as compile-time definitions by the caller. Because the text
// sits inside a macro argument it must NOT contain preprocessor directives
// (#pragma, #define, ...).
//
// Thread layout implied by the indexing below:
//   1D: threadIdx.x = point index (blockDim.x is expected to equal Q1D),
//       threadIdx.z = element slot in the block; components are looped over.
//   2D/3D: threadIdx.x/threadIdx.y = x/y point index,
//       threadIdx.z = elementSlot*BASIS_NCOMP + comp; in 3D the z direction
//       is held in per-thread register arrays of length Q1D.
// Each (element slot[, component]) owns a Q1D (1D) or Q1D*Q1D (2D/3D) window
// of the dynamically sized shared array `slice`.
// The padding logic appears to assume P1D <= Q1D.
//
// All element loops advance on the block's first element so that every
// thread of a block runs the same number of iterations; this keeps the
// __syncthreads() calls in the contractions uniform across the block.
// Threads whose element index is >= nelem read zeros and skip their stores.
//*********************
// *INDENT-OFF*
static const char *kernelsShared = QUOTE(

// r_V[i] += r_U[i] for the Q1D entries of two per-thread register arrays.
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  for (int i = 0; i < Q1D; i++)
    r_V[i] += r_U[i];
}

//////////
//  1D  //
//////////

// Load node `tidx` of component `comp` of element `elem` (layout
// [elem][comp][node]) into this element slot window of `slice`.
// Entries P1D..Q1D-1, and all entries of out-of-range elements, are zero.
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int tidz, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *slice) {
  slice[tidx + tidz*Q1D] = (tidx < P1D && elem < nelem)
                           ? d_U[tidx + comp*P1D + elem*BASIS_NCOMP*P1D]
                           : 0.0;
}

// Store node `tidx` (only for tidx < P1D) of component `comp` of element
// `elem`; nothing is written for out-of-range elements.
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx < P1D && elem < nelem) {
    d_V[tidx + comp*P1D + elem*BASIS_NCOMP*P1D] = r_V;
  }
}

// Load quadrature point `tidx` of (dim, comp, elem) from the quadrature
// layout [dim][comp][elem][point] into this element slot window of `slice`.
// Out-of-range elements load zero.
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int tidz, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *slice) {
  slice[tidx + tidz*Q1D] = (elem < nelem)
                           ? d_U[tidx + elem*Q1D + comp*Q1D*nelem +
                                 dim*BASIS_NCOMP*nelem*Q1D]
                           : 0.0;
}

// Store quadrature point `tidx` of (dim, comp, elem); skipped for
// out-of-range elements.
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  if (elem < nelem) {
    d_V[tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D] = r_V;
  }
}

// V = sum_{i<P1D} B[tidx][i] * slice[i], with B stored row-major as
// Q1D x P1D. The first barrier publishes the cooperatively loaded window;
// the second keeps it intact until every thread has finished reading it.
// U is unused (kept for signature symmetry with the 2D/3D contractions).
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  __syncthreads();
  V = 0.0;
  for (int i = 0; i < P1D; ++i) {
    V += B[i + tidx*P1D] * slice[i + tidz*Q1D];  // contract x direction
  }
  __syncthreads();
}

// Transposed contraction: V = sum_{i<Q1D} B[i][tidx] * slice[i] for
// tidx < P1D (0 otherwise). The guard also keeps B reads inside the
// Q1D x P1D array, matching ContractTransposeX2d.
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar &U, const CeedScalar *B,
                                            CeedScalar &V) {
  __syncthreads();
  V = 0.0;
  if (tidx < P1D) {
    for (int i = 0; i < Q1D; ++i) {
      V += B[tidx + i*P1D] * slice[i + tidz*Q1D];  // contract x direction
    }
  }
  __syncthreads();
}

// Apply the 1D interpolation matrix c_B (or its transpose) to every
// component of every element.
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const CeedScalar r_unused = 0.0;
  CeedScalar r_V;

  for (CeedInt firstElem = blockIdx.x*blockDim.z; firstElem < nelem;
       firstElem += gridDim.x*blockDim.z) {
    const CeedInt elem = firstElem + tidz;
    for (int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) {
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_unused, c_B, r_V);
        writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      } else {
        readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_unused, c_B, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}

// Apply the 1D gradient matrix c_G (or its transpose); c_B is unused in 1D.
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const CeedScalar r_unused = 0.0;
  CeedScalar r_V;

  for (CeedInt firstElem = blockIdx.x*blockDim.z; firstElem < nelem;
       firstElem += gridDim.x*blockDim.z) {
    const CeedInt elem = firstElem + tidz;
    for (int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) {
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_unused, c_G, r_V);
        writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      } else {
        readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_unused, c_G, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}

//////////
//  2D  //
//////////

// Node (tidx, tidy) of (comp, elem) from layout [elem][comp][y][x];
// zero for padding points (>= P1D) and out-of-range elements.
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar &U) {
  U = (tidx < P1D && tidy < P1D && elem < nelem)
      ? d_U[tidx + tidy*P1D + comp*P1D*P1D + elem*BASIS_NCOMP*P1D*P1D]
      : 0.0;
}

// Store node (tidx, tidy) of (comp, elem); padding points and
// out-of-range elements are skipped.
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx < P1D && tidy < P1D && elem < nelem) {
    d_V[tidx + tidy*P1D + comp*P1D*P1D + elem*BASIS_NCOMP*P1D*P1D] = r_V;
  }
}

// Quadrature point (tidx, tidy) of (dim, comp, elem) from layout
// [dim][comp][elem][y][x]; zero for out-of-range elements.
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar &U) {
  U = (elem < nelem)
      ? d_U[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
            dim*BASIS_NCOMP*nelem*Q1D*Q1D]
      : 0.0;
}

// Store quadrature point (tidx, tidy) of (dim, comp, elem); skipped for
// out-of-range elements.
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  if (elem < nelem) {
    d_V[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
        dim*BASIS_NCOMP*nelem*Q1D*Q1D] = r_V;
  }
}

// Publish U to the shared window, then V = sum_{i<P1D} B[tidx][i] * U(i, tidy).
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx + tidy*Q1D + tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  for (int i = 0; i < P1D; ++i) {
    V += B[i + tidx*P1D] * slice[i + tidy*Q1D + tidz*Q1D*Q1D];  // contract x direction
  }
  __syncthreads();
}

// Publish U to the shared window, then V = sum_{i<P1D} B[tidy][i] * U(tidx, i).
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx + tidy*Q1D + tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  for (int i = 0; i < P1D; ++i) {
    V += B[i + tidy*P1D] * slice[tidx + i*Q1D + tidz*Q1D*Q1D];  // contract y direction
  }
  __syncthreads();
}

// Transposed y contraction: V = sum_{i<Q1D} B[i][tidy] * U(tidx, i) for
// tidy < P1D, 0 otherwise.
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar &U, const CeedScalar *B,
                                            CeedScalar &V) {
  slice[tidx + tidy*Q1D + tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidy < P1D) {
    for (int i = 0; i < Q1D; ++i) {
      V += B[tidy + i*P1D] * slice[tidx + i*Q1D + tidz*Q1D*Q1D];  // contract y direction
    }
  }
  __syncthreads();
}

// Transposed x contraction: V = sum_{i<Q1D} B[i][tidx] * U(i, tidy) for
// tidx < P1D, 0 otherwise.
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar &U, const CeedScalar *B,
                                            CeedScalar &V) {
  slice[tidx + tidy*Q1D + tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidx < P1D) {
    for (int i = 0; i < Q1D; ++i) {
      V += B[tidx + i*P1D] * slice[i + tidy*Q1D + tidz*Q1D*Q1D];  // contract x direction
    }
  }
  __syncthreads();
}

// Tensor-product interpolation (c_B along x and y) or its transpose.
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt firstElem = blockIdx.x*elemsPerBlock; firstElem < nelem;
       firstElem += gridDim.x*elemsPerBlock) {
    const CeedInt elem = firstElem + blockElem;
    r_V = 0.0;
    r_t = 0.0;
    if (!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
    } else {
      readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}

// Tensor-product gradient. Output dim 0 applies c_G along x and c_B along y;
// dim 1 applies c_B along x and c_G along y. The transpose sums both
// transposed contributions into the nodal result.
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V, CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt firstElem = blockIdx.x*elemsPerBlock; firstElem < nelem;
       firstElem += gridDim.x*elemsPerBlock) {
    const CeedInt elem = firstElem + blockElem;
    if (!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 1, nelem, r_V, d_V);
    } else {
      readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      readQuads2d(elem, tidx, tidy, comp, 1, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      r_V += r_U;
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}

//////////
//  3D  //
//////////

// Column (tidx, tidy, 0..P1D-1) of (comp, elem) from layout
// [elem][comp][z][y][x] into r_U; entries P1D..Q1D-1, padding points and
// out-of-range elements are zero.
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *r_U) {
  const bool inside = tidx < P1D && tidy < P1D && elem < nelem;
  for (int i = 0; i < P1D; i++)
    r_U[i] = inside ? d_U[tidx + tidy*P1D + i*P1D*P1D + comp*P1D*P1D*P1D +
                          elem*BASIS_NCOMP*P1D*P1D*P1D]
                    : 0.0;
  for (int i = P1D; i < Q1D; i++)
    r_U[i] = 0.0;
}

// Column (tidx, tidy, 0..Q1D-1) of (dim, comp, elem) from layout
// [dim][comp][elem][z][y][x]; zeros for out-of-range elements.
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *r_U) {
  for (int i = 0; i < Q1D; i++)
    r_U[i] = (elem < nelem)
             ? d_U[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D +
                   comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D]
             : 0.0;
}

// Store column (tidx, tidy, 0..P1D-1) of (comp, elem); padding points and
// out-of-range elements are skipped.
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V,
                                   CeedScalar *d_V) {
  if (tidx < P1D && tidy < P1D && elem < nelem) {
    for (int i = 0; i < P1D; i++)
      d_V[tidx + tidy*P1D + i*P1D*P1D + comp*P1D*P1D*P1D +
          elem*BASIS_NCOMP*P1D*P1D*P1D] = r_V[i];
  }
}

// Store column (tidx, tidy, 0..Q1D-1) of (dim, comp, elem); skipped for
// out-of-range elements.
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar *r_V, CeedScalar *d_V) {
  if (elem < nelem) {
    for (int i = 0; i < Q1D; i++)
      d_V[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
          dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] = r_V[i];
  }
}

// For each z level k < P1D: V[k] = sum_{i<P1D} B[tidx][i] * U[k](i, tidy),
// exchanging U through the shared window.
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*Q1D + tidz*Q1D*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    for (int i = 0; i < P1D; ++i) {
      V[k] += B[i + tidx*P1D] * slice[i + tidy*Q1D + tidz*Q1D*Q1D];  // contract x direction
    }
    __syncthreads();
  }
}

// For each z level k < P1D: V[k] = sum_{i<P1D} B[tidy][i] * U[k](tidx, i).
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*Q1D + tidz*Q1D*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    for (int i = 0; i < P1D; ++i) {
      V[k] += B[i + tidy*P1D] * slice[tidx + i*Q1D + tidz*Q1D*Q1D];  // contract y direction
    }
    __syncthreads();
  }
}

// Register-only z contraction: V[k] = sum_{i<P1D} B[k][i] * U[i], k < Q1D.
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < Q1D; ++k) {
    V[k] = 0.0;
    for (int i = 0; i < P1D; ++i) {
      V[k] += B[i + k*P1D] * U[i];  // contract z direction
    }
  }
}

// Register-only transposed z contraction: V[k] = sum_{i<Q1D} B[i][k] * U[i]
// for k < P1D, 0 for k >= P1D.
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U, const CeedScalar *B,
                                            CeedScalar *V) {
  for (int k = 0; k < Q1D; ++k) {
    V[k] = 0.0;
    if (k < P1D) {
      for (int i = 0; i < Q1D; ++i) {
        V[k] += B[k + i*P1D] * U[i];  // contract z direction
      }
    }
  }
}

// For each z level k < P1D: V[k] = sum_{i<Q1D} B[i][tidy] * U[k](tidx, i)
// for tidy < P1D, 0 otherwise.
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U, const CeedScalar *B,
                                            CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*Q1D + tidz*Q1D*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidy < P1D) {
      for (int i = 0; i < Q1D; ++i) {
        V[k] += B[tidy + i*P1D] * slice[tidx + i*Q1D + tidz*Q1D*Q1D];  // contract y direction
      }
    }
    __syncthreads();
  }
}

// For each z level k < P1D: V[k] = sum_{i<Q1D} B[i][tidx] * U[k](i, tidy)
// for tidx < P1D, 0 otherwise.
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U, const CeedScalar *B,
                                            CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*Q1D + tidz*Q1D*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < P1D) {
      for (int i = 0; i < Q1D; ++i) {
        V[k] += B[tidx + i*P1D] * slice[i + tidy*Q1D + tidz*Q1D*Q1D];  // contract x direction
      }
    }
    __syncthreads();
  }
}

// Tensor-product interpolation (c_B along x, y and z) or its transpose.
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V[Q1D];
  CeedScalar r_t[Q1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt firstElem = blockIdx.x*elemsPerBlock; firstElem < nelem;
       firstElem += gridDim.x*elemsPerBlock) {
    const CeedInt elem = firstElem + blockElem;
    for (int i = 0; i < Q1D; ++i) {
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
    } else {
      readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
    }
  }
}

// Tensor-product gradient. Output dim d applies c_G along direction d and
// c_B along the other two; the transpose accumulates all three transposed
// contributions into the nodal result.
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // use P1D for one of these
  CeedScalar r_U[Q1D];
  CeedScalar r_V[Q1D];
  CeedScalar r_t[Q1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt firstElem = blockIdx.x*elemsPerBlock; firstElem < nelem;
       firstElem += gridDim.x*elemsPerBlock) {
    const CeedInt elem = firstElem + blockElem;
    if (!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_G, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeQuads3d(elem, tidx, tidy, comp, 1, nelem, r_V, d_V);
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      writeQuads3d(elem, tidx, tidy, comp, 2, nelem, r_V, d_V);
    } else {
      readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      readQuads3d(elem, tidx, tidy, comp, 1, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_G, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      readQuads3d(elem, tidx, tidy, comp, 2, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}

/////////////
// Kernels //
/////////////

// Interpolation entry point. Dynamic shared memory must provide one window
// per element slot (the host sizes it in units of sizeof(CeedScalar)), so
// the shared array is declared with that same element type.
extern "C" __global__ void interp(const CeedInt nelem, const int transpose,
                                  const CeedScalar *c_B,
                                  const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  extern __shared__ CeedScalar slice[];
  if (BASIS_DIM == 1) {
    interp1d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM == 2) {
    interp2d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM == 3) {
    interp3d(nelem, transpose, c_B, d_U, d_V, slice);
  }
}

// Gradient entry point; same shared-memory contract as interp.
extern "C" __global__ void grad(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *c_G,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V) {
  extern __shared__ CeedScalar slice[];
  if (BASIS_DIM == 1) {
    grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM == 2) {
    grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM == 3) {
    grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  }
}

/////////////
// Weights //
/////////////

// w[elem][x] = qweight1d[x]; threadIdx.x = point, threadIdx.y = element slot.
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int tid = threadIdx.x;
  const CeedScalar weight = qweight1d[tid];
  for (CeedInt elem = blockIdx.x*blockDim.y + threadIdx.y; elem < nelem;
       elem += gridDim.x*blockDim.y) {
    const int ind = elem*Q1D + tid;
    w[ind] = weight;
  }
}

// w[elem][y][x] = qweight1d[x]*qweight1d[y]; threadIdx.z = element slot.
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int i = threadIdx.x;
  const int j = threadIdx.y;
  const CeedScalar weight = qweight1d[i]*qweight1d[j];
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    const int ind = elem*Q1D*Q1D + i + j*Q1D;
    w[ind] = weight;
  }
}

// w[elem][z][y][x] = qweight1d[x]*qweight1d[y]*qweight1d[z]; one element
// per block per iteration.
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int i = threadIdx.x;
  const int j = threadIdx.y;
  const int k = threadIdx.z;
  const CeedScalar weight = qweight1d[i]*qweight1d[j]*qweight1d[k];
  for (CeedInt e = blockIdx.x; e < nelem; e += gridDim.x) {
    const int ind = e*Q1D*Q1D*Q1D + i + j*Q1D + k*Q1D*Q1D;
    w[ind] = weight;
  }
}

// Quadrature-weight entry point (no shared memory used).
extern "C" __global__ void weight(const CeedInt nelem,
                                  const CeedScalar *__restrict__ qweight1d,
                                  CeedScalar *__restrict__ v) {
  if (BASIS_DIM == 1) {
    weight1d(nelem, qweight1d, v);
  } else if (BASIS_DIM == 2) {
    weight2d(nelem, qweight1d, v);
  } else if (BASIS_DIM == 3) {
    weight3d(nelem, qweight1d, v);
  }
}

);
// *INDENT-ON*

int CeedCudaInitInterp(CeedScalar *d_B, CeedInt P1d, CeedInt Q1d,
                       CeedScalar **c_B);
int CeedCudaInitInterpGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P1d,
                           CeedInt Q1d, CeedScalar **c_B_ptr,
                           CeedScalar **c_G_ptr);

int
CeedBasisApplyTensor_Cuda_shared(CeedBasis basis, const CeedInt nelem, 642c532df63SYohann CeedTransposeMode tmode, 6437f823360Sjeremylt CeedEvalMode emode, CeedVector u, 6447f823360Sjeremylt CeedVector v) { 645c532df63SYohann int ierr; 646c532df63SYohann Ceed ceed; 647c532df63SYohann ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr); 648c532df63SYohann Ceed_Cuda_shared *ceed_Cuda; 649c532df63SYohann CeedGetData(ceed, (void *) &ceed_Cuda); CeedChk(ierr); 650c532df63SYohann CeedBasis_Cuda_shared *data; 651c532df63SYohann CeedBasisGetData(basis, (void *)&data); CeedChk(ierr); 652c532df63SYohann const CeedInt transpose = tmode == CEED_TRANSPOSE; 6534247ecf3SYohann Dudouit CeedInt dim, ncomp; 654074be161SYohann Dudouit ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr); 6554247ecf3SYohann Dudouit ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr); 656c532df63SYohann 657c532df63SYohann const CeedScalar *d_u; 658c532df63SYohann CeedScalar *d_v; 659c532df63SYohann if(emode!=CEED_EVAL_WEIGHT) { 660c532df63SYohann ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChk(ierr); 661c532df63SYohann } 662c532df63SYohann ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChk(ierr); 663c532df63SYohann 664c532df63SYohann if (tmode == CEED_TRANSPOSE) { 665c532df63SYohann CeedInt length; 666c532df63SYohann ierr = CeedVectorGetLength(v, &length); CeedChk(ierr); 667c532df63SYohann ierr = cudaMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChk(ierr); 668c532df63SYohann } 669c532df63SYohann if (emode == CEED_EVAL_INTERP) { 670c532df63SYohann CeedInt P1d, Q1d; 671c532df63SYohann ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr); 672c532df63SYohann ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr); 673c532df63SYohann ierr = CeedCudaInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B); 674c532df63SYohann CeedChk(ierr); 675cb0b5415Sjeremylt void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B, 
676*ccf0fe6fSjeremylt &d_u, &d_v 677*ccf0fe6fSjeremylt }; 6784d537eeaSYohann if (dim==1) { 679d94769d2SYohann Dudouit CeedInt elemsPerBlock = 32; 6804d537eeaSYohann CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 6814d537eeaSYohann ? 1 : 0 ); 682d94769d2SYohann Dudouit CeedInt sharedMem = elemsPerBlock*Q1d*sizeof(CeedScalar); 6834d537eeaSYohann ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, 1, 6844d537eeaSYohann elemsPerBlock, sharedMem, 685c532df63SYohann interpargs); 686c532df63SYohann CeedChk(ierr); 687074be161SYohann Dudouit } else if (dim==2) { 6884247ecf3SYohann Dudouit const CeedInt optElems[7] = {0,32,8,6,4,2,8}; 6894247ecf3SYohann Dudouit CeedInt elemsPerBlock = Q1d < 7 ? optElems[Q1d]/ncomp : 1; 6904d537eeaSYohann CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 6914d537eeaSYohann ? 1 : 0 ); 6924247ecf3SYohann Dudouit CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar); 6934d537eeaSYohann ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, Q1d, 6944d537eeaSYohann ncomp*elemsPerBlock, sharedMem, 695074be161SYohann Dudouit interpargs); 696074be161SYohann Dudouit CeedChk(ierr); 697074be161SYohann Dudouit } else if (dim==3) { 6983f63d318SYohann Dudouit CeedInt elemsPerBlock = 1; 6994d537eeaSYohann CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 7004d537eeaSYohann ? 
1 : 0 ); 701698ebc35SYohann Dudouit CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar); 7024d537eeaSYohann ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, Q1d, 7034d537eeaSYohann ncomp*elemsPerBlock, sharedMem, 704074be161SYohann Dudouit interpargs); 705074be161SYohann Dudouit CeedChk(ierr); 706074be161SYohann Dudouit } 707c532df63SYohann } else if (emode == CEED_EVAL_GRAD) { 708c532df63SYohann CeedInt P1d, Q1d; 709c532df63SYohann ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr); 710c532df63SYohann ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr); 711c532df63SYohann ierr = CeedCudaInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d, 712c532df63SYohann Q1d, &data->c_B, &data->c_G); 713c532df63SYohann CeedChk(ierr); 714cb0b5415Sjeremylt void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B, 715*ccf0fe6fSjeremylt &data->c_G, &d_u, &d_v 716*ccf0fe6fSjeremylt }; 7174d537eeaSYohann if (dim==1) { 718d94769d2SYohann Dudouit CeedInt elemsPerBlock = 32; 7194d537eeaSYohann CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 7204d537eeaSYohann ? 1 : 0 ); 721d94769d2SYohann Dudouit CeedInt sharedMem = elemsPerBlock*Q1d*sizeof(CeedScalar); 7224d537eeaSYohann ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, 1, elemsPerBlock, 7234d537eeaSYohann sharedMem, 724c532df63SYohann gradargs); 725c532df63SYohann CeedChk(ierr); 726074be161SYohann Dudouit } else if (dim==2) { 7274247ecf3SYohann Dudouit const CeedInt optElems[7] = {0,32,8,6,4,2,8}; 7284247ecf3SYohann Dudouit CeedInt elemsPerBlock = Q1d < 7 ? optElems[Q1d]/ncomp : 1; 7294d537eeaSYohann CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 7304d537eeaSYohann ? 
1 : 0 ); 7314247ecf3SYohann Dudouit CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar); 7324d537eeaSYohann ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, Q1d, 7334d537eeaSYohann ncomp*elemsPerBlock, sharedMem, 734074be161SYohann Dudouit gradargs); 735074be161SYohann Dudouit CeedChk(ierr); 736074be161SYohann Dudouit } else if (dim==3) { 7373f63d318SYohann Dudouit CeedInt elemsPerBlock = 1; 7384d537eeaSYohann CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem) 7394d537eeaSYohann ? 1 : 0 ); 740698ebc35SYohann Dudouit CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar); 7414d537eeaSYohann ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, Q1d, 7424d537eeaSYohann ncomp*elemsPerBlock, sharedMem, 743074be161SYohann Dudouit gradargs); 744074be161SYohann Dudouit CeedChk(ierr); 745074be161SYohann Dudouit } 746c532df63SYohann } else if (emode == CEED_EVAL_WEIGHT) { 747074be161SYohann Dudouit CeedInt Q1d; 748074be161SYohann Dudouit ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr); 749c532df63SYohann void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v}; 750074be161SYohann Dudouit if(dim==1) { 751074be161SYohann Dudouit const CeedInt elemsPerBlock = 32/Q1d; 7524d537eeaSYohann const CeedInt gridsize = nelem/elemsPerBlock + ( ( 7534d537eeaSYohann nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 ); 7547f823360Sjeremylt ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, 7557f823360Sjeremylt elemsPerBlock, 1, weightargs); 7561226057fSYohann Dudouit CeedChk(ierr); 757074be161SYohann Dudouit } else if(dim==2) { 758717ff8a3SYohann Dudouit const CeedInt optElems = 32/(Q1d*Q1d); 759717ff8a3SYohann Dudouit const CeedInt elemsPerBlock = optElems>0?optElems:1; 7604d537eeaSYohann const CeedInt gridsize = nelem/elemsPerBlock + ( ( 7614d537eeaSYohann nelem/elemsPerBlock*elemsPerBlock<nelem)? 
1 : 0 ); 7624d537eeaSYohann ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d, 7634d537eeaSYohann elemsPerBlock, weightargs); 7641226057fSYohann Dudouit CeedChk(ierr); 765074be161SYohann Dudouit } else if(dim==3) { 766074be161SYohann Dudouit const CeedInt gridsize = nelem; 7674d537eeaSYohann ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d, Q1d, 7684d537eeaSYohann weightargs); 7691226057fSYohann Dudouit CeedChk(ierr); 770074be161SYohann Dudouit } 771c532df63SYohann } 772c532df63SYohann 773c532df63SYohann if(emode!=CEED_EVAL_WEIGHT) { 774c532df63SYohann ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChk(ierr); 775c532df63SYohann } 776c532df63SYohann ierr = CeedVectorRestoreArray(v, &d_v); CeedChk(ierr); 777c532df63SYohann 778c532df63SYohann return 0; 779c532df63SYohann } 780c532df63SYohann 781c532df63SYohann static int CeedBasisDestroy_Cuda_shared(CeedBasis basis) { 782c532df63SYohann int ierr; 783c532df63SYohann Ceed ceed; 784c532df63SYohann ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr); 785c532df63SYohann 786c532df63SYohann CeedBasis_Cuda_shared *data; 787c532df63SYohann ierr = CeedBasisGetData(basis, (void *) &data); CeedChk(ierr); 788c532df63SYohann 789c532df63SYohann CeedChk_Cu(ceed, cuModuleUnload(data->module)); 790c532df63SYohann 791c532df63SYohann ierr = cudaFree(data->d_qweight1d); CeedChk_Cu(ceed, ierr); 792c532df63SYohann ierr = cudaFree(data->d_interp1d); CeedChk_Cu(ceed, ierr); 793c532df63SYohann ierr = cudaFree(data->d_grad1d); CeedChk_Cu(ceed, ierr); 794c532df63SYohann 795c532df63SYohann ierr = CeedFree(&data); CeedChk(ierr); 796c532df63SYohann 797c532df63SYohann return 0; 798c532df63SYohann } 799c532df63SYohann 800c532df63SYohann int CeedBasisCreateTensorH1_Cuda_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d, 801c532df63SYohann const CeedScalar *interp1d, 802c532df63SYohann const CeedScalar *grad1d, 803c532df63SYohann const CeedScalar *qref1d, 804c532df63SYohann const CeedScalar *qweight1d, 805c532df63SYohann 
CeedBasis basis) { 806c532df63SYohann int ierr; 807c532df63SYohann Ceed ceed; 808c532df63SYohann ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr); 8094d537eeaSYohann if (Q1d<P1d) { 8101226057fSYohann Dudouit return CeedError(ceed, 1, "Backend does not implement underintegrated basis."); 8111226057fSYohann Dudouit } 812c532df63SYohann CeedBasis_Cuda_shared *data; 813c532df63SYohann ierr = CeedCalloc(1, &data); CeedChk(ierr); 814c532df63SYohann 815c532df63SYohann const CeedInt qBytes = Q1d * sizeof(CeedScalar); 816c532df63SYohann ierr = cudaMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Cu(ceed, ierr); 817c532df63SYohann ierr = cudaMemcpy(data->d_qweight1d, qweight1d, qBytes, 818c532df63SYohann cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr); 819c532df63SYohann 820c532df63SYohann const CeedInt iBytes = qBytes * P1d; 821c532df63SYohann ierr = cudaMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Cu(ceed, ierr); 822c532df63SYohann ierr = cudaMemcpy(data->d_interp1d, interp1d, iBytes, 823c532df63SYohann cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr); 824c532df63SYohann 825c532df63SYohann ierr = cudaMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Cu(ceed, ierr); 826c532df63SYohann ierr = cudaMemcpy(data->d_grad1d, grad1d, iBytes, 827c532df63SYohann cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr); 828c532df63SYohann 829ac421f39SYohann data->d_collograd1d = NULL; 830ac421f39SYohann if (dim==3 && Q1d >= P1d) { 831ac421f39SYohann CeedScalar *collograd1d; 832ac421f39SYohann ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChk(ierr); 833ac421f39SYohann ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChk(ierr); 834ac421f39SYohann ierr = cudaMalloc((void **)&data->d_collograd1d, qBytes * Q1d); 835ac421f39SYohann CeedChk_Cu(ceed, ierr); 836ac421f39SYohann ierr = cudaMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d, 837ac421f39SYohann cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr); 838ac421f39SYohann } 839ac421f39SYohann 840c532df63SYohann CeedInt 
ncomp; 841c532df63SYohann ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr); 8424a6d4bbdSYohann Dudouit ierr = CeedCompileCuda(ceed, kernelsShared, &data->module, 7, 843c532df63SYohann "Q1D", Q1d, 844c532df63SYohann "P1D", P1d, 845c532df63SYohann "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ? 846c532df63SYohann Q1d : P1d, dim), 847c532df63SYohann "BASIS_DIM", dim, 848c532df63SYohann "BASIS_NCOMP", ncomp, 849c532df63SYohann "BASIS_ELEMSIZE", CeedIntPow(P1d, dim), 850c532df63SYohann "BASIS_NQPT", CeedIntPow(Q1d, dim) 851c532df63SYohann ); CeedChk(ierr); 8524a6d4bbdSYohann Dudouit ierr = CeedGetKernelCuda(ceed, data->module, "interp", &data->interp); 853c532df63SYohann CeedChk(ierr); 8544a6d4bbdSYohann Dudouit ierr = CeedGetKernelCuda(ceed, data->module, "grad", &data->grad); 855c532df63SYohann CeedChk(ierr); 8564a6d4bbdSYohann Dudouit ierr = CeedGetKernelCuda(ceed, data->module, "weight", &data->weight); 857c532df63SYohann CeedChk(ierr); 858c532df63SYohann 859c532df63SYohann ierr = CeedBasisSetData(basis, (void *)&data); 860c532df63SYohann CeedChk(ierr); 861c532df63SYohann ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply", 862c532df63SYohann CeedBasisApplyTensor_Cuda_shared); 863c532df63SYohann CeedChk(ierr); 864c532df63SYohann ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy", 865c532df63SYohann CeedBasisDestroy_Cuda_shared); 866c532df63SYohann CeedChk(ierr); 867c532df63SYohann return 0; 868c532df63SYohann } 869