xref: /libCEED/backends/cuda-shared/ceed-cuda-shared-basis.c (revision ccf0fe6fd2ec6c692fac5dfe411b8e3ec625937d)
1c532df63SYohann // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2c532df63SYohann // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3c532df63SYohann // All Rights reserved. See files LICENSE and NOTICE for details.
4c532df63SYohann //
5c532df63SYohann // This file is part of CEED, a collection of benchmarks, miniapps, software
6c532df63SYohann // libraries and APIs for efficient high-order finite element and spectral
7c532df63SYohann // element discretizations for exascale applications. For more information and
8c532df63SYohann // source code availability see http://github.com/ceed.
9c532df63SYohann //
10c532df63SYohann // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11c532df63SYohann // a collaborative effort of two U.S. Department of Energy organizations (Office
12c532df63SYohann // of Science and the National Nuclear Security Administration) responsible for
13c532df63SYohann // the planning and preparation of a capable exascale ecosystem, including
14c532df63SYohann // software, applications, hardware, advanced system engineering and early
15c532df63SYohann // testbed platforms, in support of the nation's exascale computing imperative.
16c532df63SYohann 
17c532df63SYohann #include <ceed-backend.h>
18c532df63SYohann #include <ceed.h>
19c532df63SYohann #include "ceed-cuda-shared.h"
20c532df63SYohann #include "../cuda/ceed-cuda.h"
21c532df63SYohann 
22c532df63SYohann //*********************
23c532df63SYohann // shared mem kernels
24cb0b5415Sjeremylt // *INDENT-OFF*
25c532df63SYohann static const char *kernelsShared = QUOTE(
26c532df63SYohann 
// Element-wise accumulation over the Q1D entries of two per-thread arrays:
// r_V[q] += r_U[q] for q in [0, Q1D).
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  for (int q = 0; q < Q1D; ++q) {
    r_V[q] = r_V[q] + r_U[q];
  }
}
31c532df63SYohann 
32c532df63SYohann //////////
33c532df63SYohann //  1D  //
34c532df63SYohann //////////
35c532df63SYohann 
// Stage component `comp` of element `elem` from the nodal array d_U
// (indexed as [elem][comp][P1D node]) into this z-slot's Q1D-long row of
// `slice`. Entries [0, P1D) come from d_U; entries [P1D, Q1D) are zeroed.
// Every calling thread writes the whole row, so threads sharing a tidz
// store identical values. tidx, tidy and nelem are unused.
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int tidz, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *slice) {
  CeedScalar *row = slice + tidz*Q1D;
  const int offset = elem*BASIS_NCOMP*P1D + comp*P1D;
  int node = 0;
  for (; node < P1D; ++node)
    row[node] = d_U[offset + node];
  for (; node < Q1D; ++node)
    row[node] = 0.0;
}
45c532df63SYohann 
// Store this thread's value r_V as node tidx of component `comp` of element
// `elem` in d_V (indexed as [elem][comp][P1D node]). Threads with
// tidx >= P1D write nothing. tidy and nelem are unused.
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D)
    return;
  const int offset = elem*BASIS_NCOMP*P1D + comp*P1D;
  d_V[offset + tidx] = r_V;
}
54c532df63SYohann 
// Copy the Q1D quadrature values of direction `dim`, component `comp`,
// element `elem` from d_U (indexed as [dim][comp][elem][Q1D point]) into
// this z-slot's row of `slice`. Every calling thread writes the whole row.
// tidx and tidy are unused.
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int tidz, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *slice) {
  const int offset = dim*BASIS_NCOMP*nelem*Q1D + comp*Q1D*nelem + elem*Q1D;
  CeedScalar *row = slice + tidz*Q1D;
  for (int q = 0; q < Q1D; ++q)
    row[q] = d_U[offset + q];
}
63c532df63SYohann 
// Store r_V as quadrature point tidx of direction `dim`, component `comp`,
// element `elem` in d_V (indexed as [dim][comp][elem][Q1D point]).
// No bounds check on tidx: presumably blockDim.x == Q1D -- TODO confirm.
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  const int offset = dim*BASIS_NCOMP*nelem*Q1D + comp*Q1D*nelem + elem*Q1D;
  d_V[offset + tidx] = r_V;
}
70c532df63SYohann 
// x-contraction of this z-slot's row of `slice` with row tidx of B, where B
// is stored row-major with P1D columns:
//   V = sum_{i < P1D} B[tidx*P1D + i] * row[i]
// U, tidy are unused.
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  const CeedScalar *row = slice + tidz*Q1D;
  const CeedScalar *Brow = B + tidx*P1D;
  CeedScalar acc = 0.0;
  for (int i = 0; i < P1D; ++i)
    acc += Brow[i] * row[i];
  V = acc;
}
80c532df63SYohann 
// Transposed x-contraction: uses column tidx of B (row-major, P1D columns)
// against this z-slot's row of `slice`:
//   V = sum_{i < Q1D} B[i*P1D + tidx] * row[i]
// U, tidy are unused.
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  const CeedScalar *row = slice + tidz*Q1D;
  CeedScalar acc = 0.0;
  for (int i = 0; i < Q1D; ++i)
    acc += B[i*P1D + tidx] * row[i];
  V = acc;
}
89c532df63SYohann 
// 1D interpolation for every component of every element.
// transpose == 0: nodal values (d_U) -> quadrature values (d_V).
// otherwise:      quadrature values (d_U) -> nodal values (d_V).
// threadIdx.z selects the element within the block's batch and each z-slot
// owns a Q1D-long row of the shared buffer `slice`, so shared memory must
// hold blockDim.z*Q1D scalars. Assumes blockDim.x == Q1D -- NOTE(review):
// confirm against the launch code.
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V = 0.0;
  CeedScalar r_t = 0.0;  // only fills the (unused) U slot of the contractions

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;

  // The trip count depends only on blockIdx.x, so every thread of the block
  // reaches the __syncthreads() below. Threads whose element is past nelem
  // keep joining the barrier but touch neither shared nor global memory.
  for (CeedInt firstElem = blockIdx.x*blockDim.z; firstElem < nelem;
       firstElem += gridDim.x*blockDim.z) {
    const CeedInt elem = firstElem + tidz;
    const bool active = elem < nelem;
    for (int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (active) {
        if (!transpose) {
          readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
          ContractX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
          writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
        } else {
          readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
          ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
          writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
        }
      }
      // Every thread of a z-slot reads the slot's whole slice row; the next
      // component must not overwrite it until all of those reads are done.
      __syncthreads();
    }
  }
}
118c532df63SYohann 
// 1D gradient for every component of every element (single direction,
// dim 0). transpose == 0: nodal values -> derivative at quadrature points
// using c_G; otherwise applies the transposed contraction with c_G.
// c_B is not used in 1D. threadIdx.z selects the element within the block's
// batch and each z-slot owns a Q1D-long row of `slice` (blockDim.z*Q1D
// scalars of shared memory). Assumes blockDim.x == Q1D -- NOTE(review):
// confirm against the launch code.
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  CeedScalar r_U = 0.0;  // only fills the (unused) U slot of the contractions
  CeedScalar r_V = 0.0;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int dim = 0;  // 1D has a single derivative direction

  // Block-uniform trip count so all threads reach the __syncthreads() below;
  // threads whose element is past nelem only join the barrier.
  for (CeedInt firstElem = blockIdx.x*blockDim.z; firstElem < nelem;
       firstElem += gridDim.x*blockDim.z) {
    const CeedInt elem = firstElem + tidz;
    const bool active = elem < nelem;
    for (int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (active) {
        if (!transpose) {
          readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
          ContractX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
          writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
        } else {
          readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice);
          ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
          writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
        }
      }
      // The slice row is shared by all threads of the z-slot; wait for every
      // read of this component before the next one overwrites it.
      __syncthreads();
    }
  }
}
149c532df63SYohann //////////
150c532df63SYohann //  2D  //
151c532df63SYohann //////////
152c532df63SYohann 
// Load node (tidx, tidy) of component `comp` of element `elem` from d_U
// (indexed as [elem][comp][P1D*P1D nodes], x fastest) into U. Threads
// outside the P1D x P1D node range get 0. nelem is unused.
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar &U) {
  if (tidx < P1D && tidy < P1D) {
    const int offset = elem*BASIS_NCOMP*P1D*P1D + comp*P1D*P1D;
    U = d_U[offset + tidy*P1D + tidx];
  } else {
    U = 0.0;
  }
}
162c532df63SYohann 
// Store r_V as node (tidx, tidy) of component `comp` of element `elem` in
// d_V (indexed as [elem][comp][P1D*P1D nodes], x fastest). Threads outside
// the P1D x P1D node range write nothing. nelem is unused.
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D)
    return;
  const int offset = elem*BASIS_NCOMP*P1D*P1D + comp*P1D*P1D;
  d_V[offset + tidy*P1D + tidx] = r_V;
}
171c532df63SYohann 
// Load quadrature point (tidx, tidy) of direction `dim`, component `comp`,
// element `elem` from d_U (indexed as [dim][comp][elem][Q1D*Q1D points],
// x fastest) into U.
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar &U ) {
  const int quadsPerElem = Q1D*Q1D;
  const int offset = dim*BASIS_NCOMP*nelem*quadsPerElem +
                     comp*nelem*quadsPerElem + elem*quadsPerElem;
  U = d_U[offset + tidy*Q1D + tidx];
}
179c532df63SYohann 
// Store r_V as quadrature point (tidx, tidy) of direction `dim`, component
// `comp`, element `elem` in d_V (indexed as [dim][comp][elem][Q1D*Q1D
// points], x fastest). No bounds check on tidx/tidy.
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  const int quadsPerElem = Q1D*Q1D;
  const int offset = dim*BASIS_NCOMP*nelem*quadsPerElem +
                     comp*nelem*quadsPerElem + elem*quadsPerElem;
  d_V[offset + tidy*Q1D + tidx] = r_V;
}
187c532df63SYohann 
// x-contraction in 2D. Each thread publishes U into cell (tidx, tidy) of
// this z-slot's Q1D x Q1D plane of `slice`, then computes
//   V = sum_{i < P1D} B[tidx*P1D + i] * plane[tidy][i].
// Calls __syncthreads() twice, so every thread of the block must call it.
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  CeedScalar *plane = slice + tidz*Q1D*Q1D;
  plane[tidy*Q1D + tidx] = U;
  __syncthreads();  // the whole plane must be written before it is read
  CeedScalar acc = 0.0;
  for (int i = 0; i < P1D; ++i)
    acc += B[tidx*P1D + i] * plane[tidy*Q1D + i];
  V = acc;
  __syncthreads();  // reads done before the plane can be reused
}
200c532df63SYohann 
// y-contraction in 2D. Each thread publishes U into cell (tidx, tidy) of
// this z-slot's Q1D x Q1D plane of `slice`, then computes
//   V = sum_{i < P1D} B[tidy*P1D + i] * plane[i][tidx].
// Calls __syncthreads() twice, so every thread of the block must call it.
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  CeedScalar *plane = slice + tidz*Q1D*Q1D;
  plane[tidy*Q1D + tidx] = U;
  __syncthreads();
  CeedScalar acc = 0.0;
  for (int i = 0; i < P1D; ++i)
    acc += B[tidy*P1D + i] * plane[i*Q1D + tidx];
  V = acc;
  __syncthreads();
}
213c532df63SYohann 
// Transposed y-contraction in 2D. Each thread publishes U into cell
// (tidx, tidy) of this z-slot's plane of `slice`. Threads with tidy < P1D
// then compute
//   V = sum_{i < Q1D} B[i*P1D + tidy] * plane[i][tidx],
// and all other threads get V = 0. Both __syncthreads() are executed by
// every thread regardless of tidy.
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  CeedScalar *plane = slice + tidz*Q1D*Q1D;
  plane[tidy*Q1D + tidx] = U;
  __syncthreads();
  CeedScalar acc = 0.0;
  if (tidy < P1D) {
    for (int i = 0; i < Q1D; ++i)
      acc += B[i*P1D + tidy] * plane[i*Q1D + tidx];
  }
  V = acc;
  __syncthreads();
}
227c532df63SYohann 
// Transposed x-contraction in 2D. Each thread publishes U into cell
// (tidx, tidy) of this z-slot's plane of `slice`. Threads with tidx < P1D
// then compute
//   V = sum_{i < Q1D} B[i*P1D + tidx] * plane[tidy][i],
// and all other threads get V = 0. Both __syncthreads() are executed by
// every thread regardless of tidx.
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  CeedScalar *plane = slice + tidz*Q1D*Q1D;
  plane[tidy*Q1D + tidx] = U;
  __syncthreads();
  CeedScalar acc = 0.0;
  if (tidx < P1D) {
    for (int i = 0; i < Q1D; ++i)
      acc += B[i*P1D + tidx] * plane[tidy*Q1D + i];
  }
  V = acc;
  __syncthreads();
}
241c532df63SYohann 
// 2D tensor-product interpolation. Each threadIdx.z slot handles one
// (element, component) pair: blockElem = tidz / BASIS_NCOMP and
// comp = tidz % BASIS_NCOMP.
// transpose == 0: nodal values -> quadrature values (x then y contraction).
// otherwise:      transposed contractions (y then x), quadrature -> nodal.
// Each z-slot uses a Q1D x Q1D plane of `slice` (per the Contract*2d
// indexing), i.e. blockDim.z*Q1D*Q1D scalars of shared memory.
// NOTE(review): assumes blockDim.x == blockDim.y == Q1D and blockDim.z a
// multiple of BASIS_NCOMP -- confirm against the launch code.
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  // The trip count depends only on blockIdx.x, never on threadIdx, so every
  // thread of the block reaches the __syncthreads() inside the Contract*2d
  // helpers. Threads whose element is past nelem still take part in those
  // barriers (on their own plane of slice) but skip all global loads/stores.
  for (CeedInt firstElem = blockIdx.x*elemsPerBlock; firstElem < nelem;
       firstElem += gridDim.x*elemsPerBlock) {
    const CeedInt elem = firstElem + blockElem;
    const bool active = elem < nelem;
    r_V = 0.0;
    r_t = 0.0;
    if (!transpose) {
      if (active) readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      if (active) writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
    } else {
      if (active) readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      if (active) writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
275c532df63SYohann 
// 2D tensor-product gradient. Each threadIdx.z slot handles one
// (element, component) pair (blockElem = tidz / BASIS_NCOMP,
// comp = tidz % BASIS_NCOMP).
// transpose == 0: writes dim 0 (c_G along x, c_B along y) and dim 1 (c_B
//   along x, c_G along y) quadrature values.
// otherwise: applies the transposed contractions to both directions and
//   sums them into one nodal value.
// Shared memory: blockDim.z*Q1D*Q1D scalars (one plane per z-slot).
// NOTE(review): assumes blockDim.x == blockDim.y == Q1D and blockDim.z a
// multiple of BASIS_NCOMP -- confirm against the launch code.
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V, CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  // Block-uniform trip count: every thread reaches the __syncthreads()
  // inside the Contract*2d helpers; out-of-range threads skip global I/O.
  for (CeedInt firstElem = blockIdx.x*elemsPerBlock; firstElem < nelem;
       firstElem += gridDim.x*elemsPerBlock) {
    const CeedInt elem = firstElem + blockElem;
    const bool active = elem < nelem;
    r_U = 0.0;
    r_V = 0.0;
    r_t = 0.0;
    if (!transpose) {
      if (active) readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // dim 0: G along x, B along y
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      if (active) writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      // dim 1: B along x, G along y
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      if (active) writeQuads2d(elem, tidx, tidy, comp, 1, nelem, r_V, d_V);
    } else {
      // dim 0 contribution
      if (active) readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      // dim 1 contribution, accumulated into r_V
      if (active) readQuads2d(elem, tidx, tidy, comp, 1, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      r_V += r_U;
      if (active) writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
318c532df63SYohann //////////
319c532df63SYohann //  3D  //
320c532df63SYohann //////////
321c532df63SYohann 
// For node column (tidx, tidy), load the P1D z-levels of component `comp`
// of element `elem` from d_U (indexed as [elem][comp][P1D^3 nodes], x
// fastest, then y, then z) into r_U[0..P1D-1]. Threads outside the
// P1D x P1D range get zeros; r_U[P1D..Q1D-1] is always zeroed.
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *r_U) {
  const bool onNode = tidx < P1D && tidy < P1D;
  const int nodesPerComp = P1D*P1D*P1D;
  const int offset = elem*BASIS_NCOMP*nodesPerComp + comp*nodesPerComp +
                     tidy*P1D + tidx;
  for (int z = 0; z < P1D; ++z)
    r_U[z] = onNode ? d_U[offset + z*P1D*P1D] : 0.0;
  for (int z = P1D; z < Q1D; ++z)
    r_U[z] = 0.0;
}
333c532df63SYohann 
// For quadrature column (tidx, tidy), load the Q1D z-levels of direction
// `dim`, component `comp`, element `elem` from d_U (indexed as
// [dim][comp][elem][Q1D^3 points], x fastest) into r_U[0..Q1D-1].
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *r_U) {
  const int quadsPerElem = Q1D*Q1D*Q1D;
  const int offset = dim*BASIS_NCOMP*nelem*quadsPerElem +
                     comp*nelem*quadsPerElem + elem*quadsPerElem +
                     tidy*Q1D + tidx;
  for (int z = 0; z < Q1D; ++z)
    r_U[z] = d_U[offset + z*Q1D*Q1D];
}
342c532df63SYohann 
// Store r_V[0..P1D-1] as the P1D z-levels of node column (tidx, tidy) of
// component `comp`, element `elem` in d_V (indexed as [elem][comp][P1D^3
// nodes], x fastest). Threads outside the P1D x P1D range write nothing.
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D)
    return;
  const int nodesPerComp = P1D*P1D*P1D;
  const int offset = elem*BASIS_NCOMP*nodesPerComp + comp*nodesPerComp +
                     tidy*P1D + tidx;
  for (int z = 0; z < P1D; ++z)
    d_V[offset + z*P1D*P1D] = r_V[z];
}
353c532df63SYohann 
// Store r_V[0..Q1D-1] as the Q1D z-levels of quadrature column (tidx, tidy)
// of direction `dim`, component `comp`, element `elem` in d_V (indexed as
// [dim][comp][elem][Q1D^3 points], x fastest). No bounds check on tidx/tidy.
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar *r_V, CeedScalar *d_V) {
  const int quadsPerElem = Q1D*Q1D*Q1D;
  const int offset = dim*BASIS_NCOMP*nelem*quadsPerElem +
                     comp*nelem*quadsPerElem + elem*quadsPerElem +
                     tidy*Q1D + tidx;
  for (int z = 0; z < Q1D; ++z)
    d_V[offset + z*Q1D*Q1D] = r_V[z];
}
362c532df63SYohann 
// x-contraction in 3D, one z-level at a time for levels [0, P1D). For each
// level k the thread publishes U[k] into cell (tidx, tidy) of this z-slot's
// plane of `slice` and computes
//   V[k] = sum_{i < P1D} B[tidx*P1D + i] * plane[tidy][i].
// V[P1D..Q1D-1] is left untouched. Two __syncthreads() per level, so every
// thread of the block must call this.
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B,
                                   CeedScalar *V) {
  CeedScalar *plane = slice + tidz*Q1D*Q1D;
  const int cell = tidy*Q1D + tidx;
  for (int k = 0; k < P1D; ++k) {
    plane[cell] = U[k];
    __syncthreads();
    CeedScalar acc = 0.0;
    for (int i = 0; i < P1D; ++i)
      acc += B[tidx*P1D + i] * plane[tidy*Q1D + i];
    V[k] = acc;
    __syncthreads();
  }
}
378c532df63SYohann 
// y-contraction in 3D, one z-level at a time for levels [0, P1D). For each
// level k the thread publishes U[k] into cell (tidx, tidy) of this z-slot's
// plane of `slice` and computes
//   V[k] = sum_{i < P1D} B[tidy*P1D + i] * plane[i][tidx].
// V[P1D..Q1D-1] is left untouched. Two __syncthreads() per level.
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B,
                                   CeedScalar *V) {
  CeedScalar *plane = slice + tidz*Q1D*Q1D;
  const int cell = tidy*Q1D + tidx;
  for (int k = 0; k < P1D; ++k) {
    plane[cell] = U[k];
    __syncthreads();
    CeedScalar acc = 0.0;
    for (int i = 0; i < P1D; ++i)
      acc += B[tidy*P1D + i] * plane[i*Q1D + tidx];
    V[k] = acc;
    __syncthreads();
  }
}
394c532df63SYohann 
// z-contraction in 3D, done entirely on the thread's own arrays (no shared
// memory, no barriers):
//   V[k] = sum_{i < P1D} B[k*P1D + i] * U[i]   for k in [0, Q1D).
// slice, tidx, tidy and tidz are unused.
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < Q1D; ++k) {
    const CeedScalar *Brow = B + k*P1D;
    V[k] = 0.0;
    for (int i = 0; i < P1D; ++i)
      V[k] += Brow[i] * U[i];
  }
}
406c532df63SYohann 
// Transposed z-contraction in 3D on the thread's own arrays:
//   V[k] = sum_{i < Q1D} B[i*P1D + k] * U[i]   for k < P1D,
//   V[k] = 0                                   for P1D <= k < Q1D.
// slice, tidx, tidy and tidz are unused.
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < Q1D; ++k) {
    V[k] = 0.0;
    if (k >= P1D)
      continue;
    for (int i = 0; i < Q1D; ++i)
      V[k] += B[i*P1D + k] * U[i];
  }
}
419c532df63SYohann 
// Transposed y-contraction in 3D, one z-level at a time for levels
// [0, P1D). For each level k the thread publishes U[k] into its cell of this
// z-slot's plane of `slice`. Then threads with tidy < P1D compute
//   V[k] = sum_{i < Q1D} B[i*P1D + tidy] * plane[i][tidx],
// and other threads get V[k] = 0. Both barriers per level are executed by
// every thread regardless of tidy.
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  CeedScalar *plane = slice + tidz*Q1D*Q1D;
  const int cell = tidy*Q1D + tidx;
  for (int k = 0; k < P1D; ++k) {
    plane[cell] = U[k];
    __syncthreads();
    CeedScalar acc = 0.0;
    if (tidy < P1D) {
      for (int i = 0; i < Q1D; ++i)
        acc += B[i*P1D + tidy] * plane[i*Q1D + tidx];
    }
    V[k] = acc;
    __syncthreads();
  }
}
436c532df63SYohann 
// Transposed x-contraction in 3D, one z-level at a time for levels
// [0, P1D). For each level k the thread publishes U[k] into its cell of this
// z-slot's plane of `slice`. Then threads with tidx < P1D compute
//   V[k] = sum_{i < Q1D} B[i*P1D + tidx] * plane[tidy][i],
// and other threads get V[k] = 0. Both barriers per level are executed by
// every thread regardless of tidx.
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  CeedScalar *plane = slice + tidz*Q1D*Q1D;
  const int cell = tidy*Q1D + tidx;
  for (int k = 0; k < P1D; ++k) {
    plane[cell] = U[k];
    __syncthreads();
    CeedScalar acc = 0.0;
    if (tidx < P1D) {
      for (int i = 0; i < Q1D; ++i)
        acc += B[i*P1D + tidx] * plane[tidy*Q1D + i];
    }
    V[k] = acc;
    __syncthreads();
  }
}
453c532df63SYohann 
// 3D tensor-product interpolation. Each threadIdx.z slot handles one
// (element, component) pair: blockElem = tidz / BASIS_NCOMP and
// comp = tidz % BASIS_NCOMP. Each thread holds one (x, y) column of Q1D
// z-levels in registers.
// transpose == 0: nodal -> quadrature (x, y, then z contraction).
// otherwise:      transposed contractions in reverse order.
// Each z-slot uses a Q1D x Q1D plane of `slice` (per the Contract*3d
// indexing), i.e. blockDim.z*Q1D*Q1D scalars of shared memory.
// NOTE(review): assumes blockDim.x == blockDim.y == Q1D and blockDim.z a
// multiple of BASIS_NCOMP -- confirm against the launch code.
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V[Q1D];
  CeedScalar r_t[Q1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  // The trip count depends only on blockIdx.x, never on threadIdx, so every
  // thread of the block reaches the __syncthreads() inside the Contract*3d
  // helpers. Threads whose element is past nelem still take part in those
  // barriers (on their own plane of slice) but skip all global loads/stores.
  for (CeedInt firstElem = blockIdx.x*elemsPerBlock; firstElem < nelem;
       firstElem += gridDim.x*elemsPerBlock) {
    const CeedInt elem = firstElem + blockElem;
    const bool active = elem < nelem;
    for (int i = 0; i < Q1D; ++i) {
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      if (active) readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      if (active) writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
    } else {
      if (active) readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      if (active) writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
    }
  }
}
490c532df63SYohann 
// 3D tensor-product gradient. Each threadIdx.z slot handles one
// (element, component) pair (blockElem = tidz / BASIS_NCOMP,
// comp = tidz % BASIS_NCOMP).
// transpose == 0: writes quadrature values for dim 0, 1 and 2, applying c_G
//   in the x, y and z contraction respectively (c_B in the others).
// otherwise: applies the matching transposed contractions to all three
//   directions and sums them into one nodal value.
// Shared memory: blockDim.z*Q1D*Q1D scalars (one plane per z-slot).
// NOTE(review): assumes blockDim.x == blockDim.y == Q1D and blockDim.z a
// multiple of BASIS_NCOMP -- confirm against the launch code.
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // Per-thread arrays indexed by z-level (Q1D entries each).
  CeedScalar r_U[Q1D];
  CeedScalar r_V[Q1D];
  CeedScalar r_t[Q1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  // Block-uniform trip count: every thread reaches the __syncthreads()
  // inside the Contract*3d helpers; out-of-range threads skip global I/O.
  for (CeedInt firstElem = blockIdx.x*elemsPerBlock; firstElem < nelem;
       firstElem += gridDim.x*elemsPerBlock) {
    const CeedInt elem = firstElem + blockElem;
    const bool active = elem < nelem;
    // Keep every entry defined: the transposed x/y contractions only write
    // levels [0, P1D), yet add() below reads all Q1D entries.
    for (int i = 0; i < Q1D; ++i) {
      r_U[i] = 0.0;
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      if (active) readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // dim 0: G along x
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      if (active) writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      // dim 1: G along y
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_G, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      if (active) writeQuads3d(elem, tidx, tidy, comp, 1, nelem, r_V, d_V);
      // dim 2: G along z
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      if (active) writeQuads3d(elem, tidx, tidy, comp, 2, nelem, r_V, d_V);
    } else {
      // dim 0 contribution (G along x)
      if (active) readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      // dim 1 contribution (G along y)
      if (active) readQuads3d(elem, tidx, tidy, comp, 1, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_G, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      // dim 2 contribution (G along z)
      if (active) readQuads3d(elem, tidx, tidy, comp, 2, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      if (active) writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
550c532df63SYohann 
551c532df63SYohann /////////////
552c532df63SYohann // Kernels //
553c532df63SYohann /////////////
// Interpolation kernel: dispatches on the compile-time BASIS_DIM to the
// matching interp1d/interp2d/interp3d helper. transpose != 0 selects the
// transposed application. Launched by CeedBasisApplyTensor_Cuda_shared with
// dynamic shared memory.
extern "C" __global__ void interp(const CeedInt nelem, const int transpose,
                                  const CeedScalar *c_B,
                                  const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  // Dynamic shared scratch for the tensor contractions. The host sizes it
  // as a count of CeedScalar values times sizeof of CeedScalar, so declare
  // it with that element type instead of hard-coding double.
  extern __shared__ CeedScalar slice[];
  if (BASIS_DIM==1) {
    interp1d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM==2) {
    interp2d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM==3) {
    interp3d(nelem, transpose, c_B, d_U, d_V, slice);
  }
}
567c532df63SYohann 
// Gradient kernel: dispatches on the compile-time BASIS_DIM to the matching
// grad1d/grad2d/grad3d helper. transpose != 0 selects the transposed
// application. Launched by CeedBasisApplyTensor_Cuda_shared with dynamic
// shared memory.
extern "C" __global__ void grad(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *c_G,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V) {
  // Dynamic shared scratch for the tensor contractions. The host requests
  // one CeedScalar per thread, sized with sizeof of CeedScalar, so the array
  // must use that element type rather than a hard-coded double.
  extern __shared__ CeedScalar slice[];
  if (BASIS_DIM==1) {
    grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM==2) {
    grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM==3) {
    grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  }
}
581c532df63SYohann 
582c532df63SYohann /////////////
583c532df63SYohann // Weights //
584c532df63SYohann /////////////
// 1D quadrature weights. threadIdx.x selects the quadrature point and
// threadIdx.y the element slot inside the block; the loop then advances by
// gridDim.x*blockDim.y elements until nelem is reached.
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int q = threadIdx.x;
  const CeedScalar wq = qweight1d[q];
  const CeedInt step = gridDim.x*blockDim.y;
  CeedInt e = blockIdx.x*blockDim.y + threadIdx.y;
  while (e < nelem) {
    w[e*Q1D + q] = wq;
    e += step;
  }
}
595c532df63SYohann 
// 2D quadrature weights. Thread (x, y) writes the tensor product
// qweight1d[x]*qweight1d[y] at local offset x + y*Q1D of every element it
// visits; threadIdx.z selects the element slot inside the block.
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy];
  const int local = qx + qy*Q1D;
  const CeedInt step = gridDim.x*blockDim.z;
  CeedInt e = blockIdx.x*blockDim.z + threadIdx.z;
  while (e < nelem) {
    w[e*Q1D*Q1D + local] = wq;
    e += step;
  }
}
607c532df63SYohann 
// 3D quadrature weights. Each block covers one element at a time with one
// thread per quadrature point (x, y, z); blocks advance by gridDim.x.
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const int qz = threadIdx.z;
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy]*qweight1d[qz];
  const int local = qx + Q1D*(qy + Q1D*qz);
  const int nqpt = Q1D*Q1D*Q1D;
  for (int e = blockIdx.x; e < nelem; e += gridDim.x) {
    w[e*nqpt + local] = wq;
  }
}
619c532df63SYohann 
// Quadrature-weight kernel: selects the weight helper for the compile-time
// BASIS_DIM; any other dimension writes nothing.
extern "C" __global__ void weight(const CeedInt nelem,
                                  const CeedScalar *__restrict__ qweight1d,
                                  CeedScalar *__restrict__ v) {
  switch (BASIS_DIM) {
  case 1:
    weight1d(nelem, qweight1d, v);
    break;
  case 2:
    weight2d(nelem, qweight1d, v);
    break;
  case 3:
    weight3d(nelem, qweight1d, v);
    break;
  default:
    break;
  }
}
631c532df63SYohann 
632c532df63SYohann );
633cb0b5415Sjeremylt // *INDENT-ON*
634c532df63SYohann 
635c532df63SYohann int CeedCudaInitInterp(CeedScalar *d_B, CeedInt P1d, CeedInt Q1d,
636c532df63SYohann                        CeedScalar **c_B);
637c532df63SYohann int CeedCudaInitInterpGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P1d,
6387f823360Sjeremylt                            CeedInt Q1d, CeedScalar **c_B_ptr,
6397f823360Sjeremylt                            CeedScalar **c_G_ptr);
640c532df63SYohann 
641c532df63SYohann int CeedBasisApplyTensor_Cuda_shared(CeedBasis basis, const CeedInt nelem,
642c532df63SYohann                                      CeedTransposeMode tmode,
6437f823360Sjeremylt                                      CeedEvalMode emode, CeedVector u,
6447f823360Sjeremylt                                      CeedVector v) {
645c532df63SYohann   int ierr;
646c532df63SYohann   Ceed ceed;
647c532df63SYohann   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
648c532df63SYohann   Ceed_Cuda_shared *ceed_Cuda;
649c532df63SYohann   CeedGetData(ceed, (void *) &ceed_Cuda); CeedChk(ierr);
650c532df63SYohann   CeedBasis_Cuda_shared *data;
651c532df63SYohann   CeedBasisGetData(basis, (void *)&data); CeedChk(ierr);
652c532df63SYohann   const CeedInt transpose = tmode == CEED_TRANSPOSE;
6534247ecf3SYohann Dudouit   CeedInt dim, ncomp;
654074be161SYohann Dudouit   ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
6554247ecf3SYohann Dudouit   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
656c532df63SYohann 
657c532df63SYohann   const CeedScalar *d_u;
658c532df63SYohann   CeedScalar *d_v;
659c532df63SYohann   if(emode!=CEED_EVAL_WEIGHT) {
660c532df63SYohann     ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChk(ierr);
661c532df63SYohann   }
662c532df63SYohann   ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChk(ierr);
663c532df63SYohann 
664c532df63SYohann   if (tmode == CEED_TRANSPOSE) {
665c532df63SYohann     CeedInt length;
666c532df63SYohann     ierr = CeedVectorGetLength(v, &length); CeedChk(ierr);
667c532df63SYohann     ierr = cudaMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChk(ierr);
668c532df63SYohann   }
669c532df63SYohann   if (emode == CEED_EVAL_INTERP) {
670c532df63SYohann     CeedInt P1d, Q1d;
671c532df63SYohann     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
672c532df63SYohann     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
673c532df63SYohann     ierr = CeedCudaInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B);
674c532df63SYohann     CeedChk(ierr);
675cb0b5415Sjeremylt     void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
676*ccf0fe6fSjeremylt                           &d_u, &d_v
677*ccf0fe6fSjeremylt                          };
6784d537eeaSYohann     if (dim==1) {
679d94769d2SYohann Dudouit       CeedInt elemsPerBlock = 32;
6804d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
6814d537eeaSYohann                                              ? 1 : 0 );
682d94769d2SYohann Dudouit       CeedInt sharedMem = elemsPerBlock*Q1d*sizeof(CeedScalar);
6834d537eeaSYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, 1,
6844d537eeaSYohann                                         elemsPerBlock, sharedMem,
685c532df63SYohann                                         interpargs);
686c532df63SYohann       CeedChk(ierr);
687074be161SYohann Dudouit     } else if (dim==2) {
6884247ecf3SYohann Dudouit       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
6894247ecf3SYohann Dudouit       CeedInt elemsPerBlock = Q1d < 7 ? optElems[Q1d]/ncomp : 1;
6904d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
6914d537eeaSYohann                                              ? 1 : 0 );
6924247ecf3SYohann Dudouit       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
6934d537eeaSYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, Q1d,
6944d537eeaSYohann                                         ncomp*elemsPerBlock, sharedMem,
695074be161SYohann Dudouit                                         interpargs);
696074be161SYohann Dudouit       CeedChk(ierr);
697074be161SYohann Dudouit     } else if (dim==3) {
6983f63d318SYohann Dudouit       CeedInt elemsPerBlock = 1;
6994d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
7004d537eeaSYohann                                              ? 1 : 0 );
701698ebc35SYohann Dudouit       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
7024d537eeaSYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, Q1d,
7034d537eeaSYohann                                         ncomp*elemsPerBlock, sharedMem,
704074be161SYohann Dudouit                                         interpargs);
705074be161SYohann Dudouit       CeedChk(ierr);
706074be161SYohann Dudouit     }
707c532df63SYohann   } else if (emode == CEED_EVAL_GRAD) {
708c532df63SYohann     CeedInt P1d, Q1d;
709c532df63SYohann     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
710c532df63SYohann     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
711c532df63SYohann     ierr = CeedCudaInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d,
712c532df63SYohann                                   Q1d, &data->c_B, &data->c_G);
713c532df63SYohann     CeedChk(ierr);
714cb0b5415Sjeremylt     void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
715*ccf0fe6fSjeremylt                         &data->c_G, &d_u, &d_v
716*ccf0fe6fSjeremylt                        };
7174d537eeaSYohann     if (dim==1) {
718d94769d2SYohann Dudouit       CeedInt elemsPerBlock = 32;
7194d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
7204d537eeaSYohann                                              ? 1 : 0 );
721d94769d2SYohann Dudouit       CeedInt sharedMem = elemsPerBlock*Q1d*sizeof(CeedScalar);
7224d537eeaSYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, 1, elemsPerBlock,
7234d537eeaSYohann                                         sharedMem,
724c532df63SYohann                                         gradargs);
725c532df63SYohann       CeedChk(ierr);
726074be161SYohann Dudouit     } else if (dim==2) {
7274247ecf3SYohann Dudouit       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
7284247ecf3SYohann Dudouit       CeedInt elemsPerBlock = Q1d < 7 ? optElems[Q1d]/ncomp : 1;
7294d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
7304d537eeaSYohann                                              ? 1 : 0 );
7314247ecf3SYohann Dudouit       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
7324d537eeaSYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, Q1d,
7334d537eeaSYohann                                         ncomp*elemsPerBlock, sharedMem,
734074be161SYohann Dudouit                                         gradargs);
735074be161SYohann Dudouit       CeedChk(ierr);
736074be161SYohann Dudouit     } else if (dim==3) {
7373f63d318SYohann Dudouit       CeedInt elemsPerBlock = 1;
7384d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
7394d537eeaSYohann                                              ? 1 : 0 );
740698ebc35SYohann Dudouit       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
7414d537eeaSYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, Q1d,
7424d537eeaSYohann                                         ncomp*elemsPerBlock, sharedMem,
743074be161SYohann Dudouit                                         gradargs);
744074be161SYohann Dudouit       CeedChk(ierr);
745074be161SYohann Dudouit     }
746c532df63SYohann   } else if (emode == CEED_EVAL_WEIGHT) {
747074be161SYohann Dudouit     CeedInt Q1d;
748074be161SYohann Dudouit     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
749c532df63SYohann     void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
750074be161SYohann Dudouit     if(dim==1) {
751074be161SYohann Dudouit       const CeedInt elemsPerBlock = 32/Q1d;
7524d537eeaSYohann       const CeedInt gridsize = nelem/elemsPerBlock + ( (
7534d537eeaSYohann                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
7547f823360Sjeremylt       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d,
7557f823360Sjeremylt                                   elemsPerBlock, 1, weightargs);
7561226057fSYohann Dudouit       CeedChk(ierr);
757074be161SYohann Dudouit     } else if(dim==2) {
758717ff8a3SYohann Dudouit       const CeedInt optElems = 32/(Q1d*Q1d);
759717ff8a3SYohann Dudouit       const CeedInt elemsPerBlock = optElems>0?optElems:1;
7604d537eeaSYohann       const CeedInt gridsize = nelem/elemsPerBlock + ( (
7614d537eeaSYohann                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
7624d537eeaSYohann       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d,
7634d537eeaSYohann                                   elemsPerBlock, weightargs);
7641226057fSYohann Dudouit       CeedChk(ierr);
765074be161SYohann Dudouit     } else if(dim==3) {
766074be161SYohann Dudouit       const CeedInt gridsize = nelem;
7674d537eeaSYohann       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d, Q1d,
7684d537eeaSYohann                                   weightargs);
7691226057fSYohann Dudouit       CeedChk(ierr);
770074be161SYohann Dudouit     }
771c532df63SYohann   }
772c532df63SYohann 
773c532df63SYohann   if(emode!=CEED_EVAL_WEIGHT) {
774c532df63SYohann     ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChk(ierr);
775c532df63SYohann   }
776c532df63SYohann   ierr = CeedVectorRestoreArray(v, &d_v); CeedChk(ierr);
777c532df63SYohann 
778c532df63SYohann   return 0;
779c532df63SYohann }
780c532df63SYohann 
781c532df63SYohann static int CeedBasisDestroy_Cuda_shared(CeedBasis basis) {
782c532df63SYohann   int ierr;
783c532df63SYohann   Ceed ceed;
784c532df63SYohann   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
785c532df63SYohann 
786c532df63SYohann   CeedBasis_Cuda_shared *data;
787c532df63SYohann   ierr = CeedBasisGetData(basis, (void *) &data); CeedChk(ierr);
788c532df63SYohann 
789c532df63SYohann   CeedChk_Cu(ceed, cuModuleUnload(data->module));
790c532df63SYohann 
791c532df63SYohann   ierr = cudaFree(data->d_qweight1d); CeedChk_Cu(ceed, ierr);
792c532df63SYohann   ierr = cudaFree(data->d_interp1d); CeedChk_Cu(ceed, ierr);
793c532df63SYohann   ierr = cudaFree(data->d_grad1d); CeedChk_Cu(ceed, ierr);
794c532df63SYohann 
795c532df63SYohann   ierr = CeedFree(&data); CeedChk(ierr);
796c532df63SYohann 
797c532df63SYohann   return 0;
798c532df63SYohann }
799c532df63SYohann 
800c532df63SYohann int CeedBasisCreateTensorH1_Cuda_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
801c532df63SYohann                                         const CeedScalar *interp1d,
802c532df63SYohann                                         const CeedScalar *grad1d,
803c532df63SYohann                                         const CeedScalar *qref1d,
804c532df63SYohann                                         const CeedScalar *qweight1d,
805c532df63SYohann                                         CeedBasis basis) {
806c532df63SYohann   int ierr;
807c532df63SYohann   Ceed ceed;
808c532df63SYohann   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
8094d537eeaSYohann   if (Q1d<P1d) {
8101226057fSYohann Dudouit     return CeedError(ceed, 1, "Backend does not implement underintegrated basis.");
8111226057fSYohann Dudouit   }
812c532df63SYohann   CeedBasis_Cuda_shared *data;
813c532df63SYohann   ierr = CeedCalloc(1, &data); CeedChk(ierr);
814c532df63SYohann 
815c532df63SYohann   const CeedInt qBytes = Q1d * sizeof(CeedScalar);
816c532df63SYohann   ierr = cudaMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Cu(ceed, ierr);
817c532df63SYohann   ierr = cudaMemcpy(data->d_qweight1d, qweight1d, qBytes,
818c532df63SYohann                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
819c532df63SYohann 
820c532df63SYohann   const CeedInt iBytes = qBytes * P1d;
821c532df63SYohann   ierr = cudaMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Cu(ceed, ierr);
822c532df63SYohann   ierr = cudaMemcpy(data->d_interp1d, interp1d, iBytes,
823c532df63SYohann                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
824c532df63SYohann 
825c532df63SYohann   ierr = cudaMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Cu(ceed, ierr);
826c532df63SYohann   ierr = cudaMemcpy(data->d_grad1d, grad1d, iBytes,
827c532df63SYohann                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
828c532df63SYohann 
829ac421f39SYohann   data->d_collograd1d = NULL;
830ac421f39SYohann   if (dim==3 && Q1d >= P1d) {
831ac421f39SYohann     CeedScalar *collograd1d;
832ac421f39SYohann     ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChk(ierr);
833ac421f39SYohann     ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChk(ierr);
834ac421f39SYohann     ierr = cudaMalloc((void **)&data->d_collograd1d, qBytes * Q1d);
835ac421f39SYohann     CeedChk_Cu(ceed, ierr);
836ac421f39SYohann     ierr = cudaMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d,
837ac421f39SYohann                       cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
838ac421f39SYohann   }
839ac421f39SYohann 
840c532df63SYohann   CeedInt ncomp;
841c532df63SYohann   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
8424a6d4bbdSYohann Dudouit   ierr = CeedCompileCuda(ceed, kernelsShared, &data->module, 7,
843c532df63SYohann                          "Q1D", Q1d,
844c532df63SYohann                          "P1D", P1d,
845c532df63SYohann                          "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
846c532df63SYohann                              Q1d : P1d, dim),
847c532df63SYohann                          "BASIS_DIM", dim,
848c532df63SYohann                          "BASIS_NCOMP", ncomp,
849c532df63SYohann                          "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
850c532df63SYohann                          "BASIS_NQPT", CeedIntPow(Q1d, dim)
851c532df63SYohann                         ); CeedChk(ierr);
8524a6d4bbdSYohann Dudouit   ierr = CeedGetKernelCuda(ceed, data->module, "interp", &data->interp);
853c532df63SYohann   CeedChk(ierr);
8544a6d4bbdSYohann Dudouit   ierr = CeedGetKernelCuda(ceed, data->module, "grad", &data->grad);
855c532df63SYohann   CeedChk(ierr);
8564a6d4bbdSYohann Dudouit   ierr = CeedGetKernelCuda(ceed, data->module, "weight", &data->weight);
857c532df63SYohann   CeedChk(ierr);
858c532df63SYohann 
859c532df63SYohann   ierr = CeedBasisSetData(basis, (void *)&data);
860c532df63SYohann   CeedChk(ierr);
861c532df63SYohann   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
862c532df63SYohann                                 CeedBasisApplyTensor_Cuda_shared);
863c532df63SYohann   CeedChk(ierr);
864c532df63SYohann   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
865c532df63SYohann                                 CeedBasisDestroy_Cuda_shared);
866c532df63SYohann   CeedChk(ierr);
867c532df63SYohann   return 0;
868c532df63SYohann }
869