xref: /libCEED/backends/cuda-shared/ceed-cuda-shared-basis.c (revision 288c044332e33f37503f09b6484fec9d0a55fba1)
1c532df63SYohann // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2c532df63SYohann // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3c532df63SYohann // All Rights reserved. See files LICENSE and NOTICE for details.
4c532df63SYohann //
5c532df63SYohann // This file is part of CEED, a collection of benchmarks, miniapps, software
6c532df63SYohann // libraries and APIs for efficient high-order finite element and spectral
7c532df63SYohann // element discretizations for exascale applications. For more information and
8c532df63SYohann // source code availability see http://github.com/ceed.
9c532df63SYohann //
10c532df63SYohann // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11c532df63SYohann // a collaborative effort of two U.S. Department of Energy organizations (Office
12c532df63SYohann // of Science and the National Nuclear Security Administration) responsible for
13c532df63SYohann // the planning and preparation of a capable exascale ecosystem, including
14c532df63SYohann // software, applications, hardware, advanced system engineering and early
15c532df63SYohann // testbed platforms, in support of the nation's exascale computing imperative.
16c532df63SYohann 
17c532df63SYohann #include <ceed-backend.h>
18c532df63SYohann #include <ceed.h>
19c532df63SYohann #include "ceed-cuda-shared.h"
20c532df63SYohann #include "../cuda/ceed-cuda.h"
21c532df63SYohann 
22c532df63SYohann //*********************
23c532df63SYohann // shared mem kernels
24c532df63SYohann static const char *kernelsShared = QUOTE(
25c532df63SYohann 
// Accumulate r_U into r_V entry by entry over all Q1D entries of the
// per-thread register column (grad3d uses it to sum the transposed
// derivative contributions).
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  for (int q = 0; q < Q1D; ++q) {
    r_V[q] = r_V[q] + r_U[q];
  }
}
30c532df63SYohann 
31c532df63SYohann //////////
32c532df63SYohann //  1D  //
33c532df63SYohann //////////
34c532df63SYohann 
// Stage the P1D nodal values of component comp of element elem into the
// z-slab row slice[tidz*Q1D + p] (p < P1D), then zero the remaining entries
// up to Q1D. The nodal layout read from d_U is [elem][comp][p].
// NOTE(review): every thread sharing tidz (all tidx) writes this same row with
// the same values. A thread that later reads entries written by another
// thread needs a barrier in between.
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int tidz, const int comp,
                                  const int nelem, const CeedScalar *d_U, CeedScalar *slice) {
  const int rowStart = tidz*Q1D;
  const int elemStart = elem*BASIS_NCOMP*P1D + comp*P1D;
  int p = 0;
  for (; p < P1D; ++p)
    slice[rowStart + p] = d_U[elemStart + p];
  // p == P1D here: pad the tail of the row with zeros.
  for (; p < Q1D; ++p)
    slice[rowStart + p] = 0.0;
}
43c532df63SYohann 
// Store this thread's nodal result r_V at node tidx of component comp of
// element elem (layout [elem][comp][p]). Threads with tidx >= P1D own no node
// and write nothing.
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D)
    return;
  const int node = elem*BASIS_NCOMP*P1D + comp*P1D + tidx;
  d_V[node] = r_V;
}
52c532df63SYohann 
// Copy the Q1D quadrature values of (dim, comp, elem) into the z-slab row
// slice[tidz*Q1D + q]. The quadrature layout of d_U is [dim][comp][elem][q].
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int tidz, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *slice) {
  const int rowStart = tidz*Q1D;
  const int fieldStart = elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  for (int q = 0; q < Q1D; q++) {
    slice[rowStart + q] = d_U[fieldStart + q];
  }
}
61c532df63SYohann 
// Store this thread's value r_V at quadrature point tidx of (dim, comp, elem)
// in d_V, whose layout is [dim][comp][elem][q].
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  const int fieldSize = Q1D*nelem;
  const int index = (dim*BASIS_NCOMP + comp)*fieldSize + elem*Q1D + tidx;
  d_V[index] = r_V;
}
68c532df63SYohann 
// Contract the P1D staged nodal values of this z-slab with row tidx of B:
//   V = sum_{p < P1D} B[tidx*P1D + p] * slice[tidz*Q1D + p].
// U and tidy are not used; the input is read from slice.
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  const CeedScalar *weights = B + tidx*P1D;
  const CeedScalar *nodal = slice + tidz*Q1D;
  CeedScalar sum = 0.0;
  for (int p = 0; p < P1D; ++p)
    sum += weights[p] * nodal[p];
  V = sum;
}
78c532df63SYohann 
// Transposed contraction from the Q1D staged quadrature values of this z-slab
// back to node tidx:
//   V = sum_{q < Q1D} B[tidx + q*P1D] * slice[tidz*Q1D + q].
// Threads with tidx >= P1D own no node and return V = 0 without touching B.
// The guard matters because B is indexed as rows of P1D entries (see
// ContractX1d), so assuming B holds P1D*Q1D entries, tidx + (Q1D-1)*P1D reads
// past its end for tidx >= P1D. ContractTransposeX2d/X3d use the same guard.
// U and tidy are not used.
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  V = 0.0;
  if (tidx < P1D) {
    for (int i = 0; i < Q1D; ++i) {
      V += B[tidx + i*P1D] * slice[i + tidz*Q1D]; // contract x direction
    }
  }
}
87c532df63SYohann 
// 1D interpolation of every component of every element.
// transpose == 0: nodes (d_U) -> quadrature points (d_V); otherwise the
// transposed map quadrature points -> nodes.
// threadIdx.x indexes the point within an element. threadIdx.z selects one of
// the blockDim.z elements handled per block pass, and that element row is
// staged in slice[tidz*Q1D ..].
//
// All tidx threads of a slab share one staged row. The __syncthreads() calls
// stop a thread from overwriting the row with the next component or element
// while a sibling thread (possibly in another warp) is still reading it.
// The loop bound (elemBase) is identical for all threads of the block, so
// every thread reaches every barrier. Threads whose element lies past nelem
// stay in the loop as inactive and skip all global and shared accesses.
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t = 0.0; // U argument of the contractions; the 1D versions do not read it

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;

  for (CeedInt elemBase = blockIdx.x*blockDim.z; elemBase < nelem;
       elemBase += gridDim.x*blockDim.z) {
    const CeedInt elem = elemBase + tidz;
    const bool active = elem < nelem;
    for (int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) {
        if (active) readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        __syncthreads(); // staged row complete before anyone reads it
        if (active) {
          ContractX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
          writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
        }
      } else {
        if (active) readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
        __syncthreads(); // staged row complete before anyone reads it
        if (active) {
          ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
          writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
        }
      }
      __syncthreads(); // all reads done before the row is restaged
    }
  }
}
116c532df63SYohann 
// 1D gradient of every component of every element (only one derivative
// direction, dim = 0). c_G is the derivative matrix; c_B is not needed in 1D.
// transpose == 0: nodes (d_U) -> quadrature-point derivatives (d_V);
// otherwise the transposed map.
// Thread layout and shared-memory staging follow interp1d: threadIdx.x is the
// point and threadIdx.z the element, and the row is slice[tidz*Q1D ..].
//
// The __syncthreads() calls stop a thread from restaging the shared row while
// sibling threads (possibly in another warp) still read it. The loop bound is
// identical across the block so every thread reaches every barrier. Threads
// whose element lies past nelem stay in the loop as inactive.
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  CeedScalar r_U = 0.0; // U argument of the contractions; the 1D versions do not read it
  CeedScalar r_V;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int dim = 0;

  for (CeedInt elemBase = blockIdx.x*blockDim.z; elemBase < nelem;
       elemBase += gridDim.x*blockDim.z) {
    const CeedInt elem = elemBase + tidz;
    const bool active = elem < nelem;
    for (int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) {
        if (active) readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        __syncthreads(); // staged row complete before anyone reads it
        if (active) {
          ContractX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
          writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
        }
      } else {
        if (active) readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice);
        __syncthreads(); // staged row complete before anyone reads it
        if (active) {
          ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
          writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
        }
      }
      __syncthreads(); // all reads done before the row is restaged
    }
  }
}
147c532df63SYohann //////////
148c532df63SYohann //  2D  //
149c532df63SYohann //////////
150c532df63SYohann 
// Load node (tidx, tidy) of component comp of element elem into U. The nodal
// layout of d_U is [elem][comp][y][x]. Threads outside the P1D x P1D node tile
// get U = 0.
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar &U) {
  if (tidx < P1D && tidy < P1D) {
    const int node = tidx + tidy*P1D;
    U = d_U[node + comp*P1D*P1D + elem*BASIS_NCOMP*P1D*P1D];
  } else {
    U = 0.0;
  }
}
159c532df63SYohann 
// Store r_V at node (tidx, tidy) of component comp of element elem
// (layout [elem][comp][y][x]). Threads outside the P1D x P1D tile write nothing.
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D)
    return;
  const int node = tidx + tidy*P1D;
  d_V[node + comp*P1D*P1D + elem*BASIS_NCOMP*P1D*P1D] = r_V;
}
168c532df63SYohann 
// Load quadrature point (tidx, tidy) of (dim, comp, elem) into U. The
// quadrature layout of d_U is [dim][comp][elem][y][x].
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar &U ) {
  const int point = tidx + tidy*Q1D;
  const int elemOffset = elem*Q1D*Q1D;
  const int fieldOffset = comp*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D;
  U = d_U[point + elemOffset + fieldOffset];
}
176c532df63SYohann 
// Store r_V at quadrature point (tidx, tidy) of (dim, comp, elem) in d_V,
// whose layout is [dim][comp][elem][y][x].
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  const int point = tidx + tidy*Q1D;
  const int elemOffset = elem*Q1D*Q1D;
  const int fieldOffset = comp*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D;
  d_V[point + elemOffset + fieldOffset] = r_V;
}
184c532df63SYohann 
// Contract along x. Each thread publishes U into its z-slab tile
// slice[tidz*Q1D*Q1D + tidy*Q1D + tidx], the block synchronizes, and then
//   V = sum_{p < P1D} B[tidx*P1D + p] * tile[tidy*Q1D + p].
// Contains two __syncthreads(), so every thread of the block must call it.
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  CeedScalar *tile = slice + tidz*Q1D*Q1D;
  tile[tidx + tidy*Q1D] = U;
  __syncthreads();
  CeedScalar sum = 0.0;
  for (int p = 0; p < P1D; ++p)
    sum += B[p + tidx*P1D] * tile[p + tidy*Q1D];
  V = sum;
  // Keep the tile intact until every thread has finished reading it.
  __syncthreads();
}
197c532df63SYohann 
// Contract along y. Each thread publishes U into its z-slab tile, the block
// synchronizes, and then
//   V = sum_{p < P1D} B[tidy*P1D + p] * tile[p*Q1D + tidx].
// Contains two __syncthreads(), so every thread of the block must call it.
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  CeedScalar *tile = slice + tidz*Q1D*Q1D;
  tile[tidx + tidy*Q1D] = U;
  __syncthreads();
  CeedScalar sum = 0.0;
  for (int p = 0; p < P1D; ++p)
    sum += B[p + tidy*P1D] * tile[tidx + p*Q1D];
  V = sum;
  // Keep the tile intact until every thread has finished reading it.
  __syncthreads();
}
210c532df63SYohann 
// Transposed contraction along y. Each thread publishes U into its z-slab
// tile, the block synchronizes, and threads with tidy < P1D compute
//   V = sum_{q < Q1D} B[tidy + q*P1D] * tile[q*Q1D + tidx];
// the others get V = 0. Contains two __syncthreads(), so every thread of the
// block must call it.
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  CeedScalar *tile = slice + tidz*Q1D*Q1D;
  tile[tidx + tidy*Q1D] = U;
  __syncthreads();
  CeedScalar sum = 0.0;
  // Rows tidy >= P1D correspond to no node and keep sum = 0.
  if (tidy < P1D) {
    for (int q = 0; q < Q1D; ++q)
      sum += B[tidy + q*P1D] * tile[tidx + q*Q1D];
  }
  V = sum;
  __syncthreads();
}
224c532df63SYohann 
// Transposed contraction along x. Each thread publishes U into its z-slab
// tile, the block synchronizes, and threads with tidx < P1D compute
//   V = sum_{q < Q1D} B[tidx + q*P1D] * tile[tidy*Q1D + q];
// the others get V = 0. Contains two __syncthreads(), so every thread of the
// block must call it.
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  CeedScalar *tile = slice + tidz*Q1D*Q1D;
  tile[tidx + tidy*Q1D] = U;
  __syncthreads();
  CeedScalar sum = 0.0;
  // Columns tidx >= P1D correspond to no node and keep sum = 0.
  if (tidx < P1D) {
    for (int q = 0; q < Q1D; ++q)
      sum += B[tidx + q*P1D] * tile[q + tidy*Q1D];
  }
  V = sum;
  __syncthreads();
}
238c532df63SYohann 
// 2D tensor-product interpolation.
// transpose == 0: nodes (d_U) -> quadrature points (d_V); otherwise the
// transposed map quadrature points -> nodes.
// threadIdx.x/y index the point in the element tile. threadIdx.z packs
// (element, component) as tidz = blockElem*BASIS_NCOMP + comp, and each tidz
// owns the shared tile slice[tidz*Q1D*Q1D ..].
// NOTE(review): assumes blockDim.z is a nonzero multiple of BASIS_NCOMP; a
// zero elemsPerBlock would make the loop stride 0 - confirm at the launch site.
//
// The contractions contain __syncthreads(), so the loop bound (elemBase) is
// identical for every thread of the block. Threads whose element is past nelem
// stay in the loop as inactive. They skip the global reads and writes but
// still take part in the contractions (on zeros in their own tile), so no
// barrier is reached by only part of the block.
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elemBase = blockIdx.x*elemsPerBlock; elemBase < nelem;
       elemBase += gridDim.x*elemsPerBlock) {
    const CeedInt elem = elemBase + blockElem;
    const bool active = elem < nelem;
    r_V = 0.0;
    r_t = 0.0;
    if (!transpose) {
      if (active) readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      if (active) writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
    } else {
      if (active) readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      if (active) writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
272c532df63SYohann 
// 2D tensor-product gradient. The x derivative uses (G in x, B in y) and the
// y derivative uses (B in x, G in y).
// transpose == 0: nodes (d_U) -> quadrature derivatives (d_V, dim = 0 and 1);
// otherwise the transposed map, summing both dim contributions into the nodes.
// Thread layout and shared tiles follow interp2d:
// tidz = blockElem*BASIS_NCOMP + comp, and each tidz owns slice[tidz*Q1D*Q1D ..].
// NOTE(review): assumes blockDim.z is a nonzero multiple of BASIS_NCOMP -
// confirm at the launch site.
//
// The contractions contain __syncthreads(), so the loop bound is identical
// for all threads. Threads whose element lies past nelem stay in the loop as
// inactive: they skip global memory but still join every barrier.
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elemBase = blockIdx.x*elemsPerBlock; elemBase < nelem;
       elemBase += gridDim.x*elemsPerBlock) {
    const CeedInt elem = elemBase + blockElem;
    const bool active = elem < nelem;
    // Inactive threads feed zeros through the contractions.
    r_U = 0.0;
    r_V = 0.0;
    r_t = 0.0;
    if (!transpose) {
      if (active) readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      if (active) writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 1;
      if (active) writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      dim = 0;
      if (active) readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 1;
      if (active) readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      r_V += r_U;
      if (active) writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
315c532df63SYohann //////////
316c532df63SYohann //  3D  //
317c532df63SYohann //////////
318c532df63SYohann 
// Load the z-column of nodes above (tidx, tidy) for component comp of element
// elem into r_U[0 .. P1D-1] (nodal layout [elem][comp][z][y][x]), and zero
// r_U[P1D .. Q1D-1]. Threads outside the P1D x P1D node tile get an all-zero
// column.
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U, CeedScalar *r_U) {
  if (tidx < P1D && tidy < P1D) {
    const int node = tidx + tidy*P1D;
    const int offset = comp*P1D*P1D*P1D + elem*BASIS_NCOMP*P1D*P1D*P1D;
    for (int k = 0; k < P1D; k++)
      r_U[k] = d_U[node + k*P1D*P1D + offset];
  } else {
    for (int k = 0; k < P1D; k++)
      r_U[k] = 0.0;
  }
  for (int k = P1D; k < Q1D; k++)
    r_U[k] = 0.0;
}
329c532df63SYohann 
// Load the z-column of quadrature values above (tidx, tidy) for
// (dim, comp, elem) into r_U[0 .. Q1D-1]. The quadrature layout of d_U is
// [dim][comp][elem][z][y][x].
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem, const CeedScalar *d_U, CeedScalar *r_U) {
  const int point = tidx + tidy*Q1D;
  const int offset = elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
                     dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
  for (int k = 0; k < Q1D; k++)
    r_U[k] = d_U[point + k*Q1D*Q1D + offset];
}
337c532df63SYohann 
// Store r_V[0 .. P1D-1] into the z-column of nodes above (tidx, tidy) for
// component comp of element elem (layout [elem][comp][z][y][x]). Threads
// outside the P1D x P1D node tile write nothing.
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V, CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D)
    return;
  const int node = tidx + tidy*P1D;
  const int offset = comp*P1D*P1D*P1D + elem*BASIS_NCOMP*P1D*P1D*P1D;
  for (int k = 0; k < P1D; k++)
    d_V[node + k*P1D*P1D + offset] = r_V[k];
}
347c532df63SYohann 
// Store r_V[0 .. Q1D-1] into the z-column of quadrature points above
// (tidx, tidy) for (dim, comp, elem). The layout of d_V is
// [dim][comp][elem][z][y][x].
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem, const CeedScalar *r_V, CeedScalar *d_V) {
  const int point = tidx + tidy*Q1D;
  const int offset = elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
                     dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
  for (int k = 0; k < Q1D; k++)
    d_V[point + k*Q1D*Q1D + offset] = r_V[k];
}
355c532df63SYohann 
// Contract along x, one z-layer at a time for layers k < P1D. Each layer is
// published into the z-slab tile slice[tidz*Q1D*Q1D ..], then
//   V[k] = sum_{p < P1D} B[tidx*P1D + p] * tile[tidy*Q1D + p].
// Entries V[k] with k >= P1D are left untouched. Contains 2*P1D
// __syncthreads(), so every thread of the block must call it.
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  CeedScalar *tile = slice + tidz*Q1D*Q1D;
  const CeedScalar *weights = B + tidx*P1D;
  for (int k = 0; k < P1D; ++k) {
    tile[tidx + tidy*Q1D] = U[k];
    __syncthreads();
    CeedScalar sum = 0.0;
    for (int p = 0; p < P1D; ++p)
      sum += weights[p] * tile[p + tidy*Q1D];
    V[k] = sum;
    __syncthreads();
  }
}
370c532df63SYohann 
// Contract along y, one z-layer at a time for layers k < P1D. Each layer is
// published into the z-slab tile, then
//   V[k] = sum_{p < P1D} B[tidy*P1D + p] * tile[p*Q1D + tidx].
// Entries V[k] with k >= P1D are left untouched. Contains 2*P1D
// __syncthreads(), so every thread of the block must call it.
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  CeedScalar *tile = slice + tidz*Q1D*Q1D;
  const CeedScalar *weights = B + tidy*P1D;
  for (int k = 0; k < P1D; ++k) {
    tile[tidx + tidy*Q1D] = U[k];
    __syncthreads();
    CeedScalar sum = 0.0;
    for (int p = 0; p < P1D; ++p)
      sum += weights[p] * tile[tidx + p*Q1D];
    V[k] = sum;
    __syncthreads();
  }
}
385c532df63SYohann 
// Contract along z entirely in registers. For every quadrature layer
// k < Q1D:
//   V[k] = sum_{p < P1D} B[k*P1D + p] * U[p].
// slice and the thread indices are not used here.
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < Q1D; ++k) {
    const CeedScalar *row = B + k*P1D;
    CeedScalar sum = 0.0;
    for (int p = 0; p < P1D; ++p)
      sum += row[p] * U[p];
    V[k] = sum;
  }
}
396c532df63SYohann 
// Transposed contraction along z in registers. For layers k < P1D:
//   V[k] = sum_{q < Q1D} B[k + q*P1D] * U[q];
// layers P1D <= k < Q1D are zeroed. slice and the thread indices are not used.
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < Q1D; ++k) {
    if (k >= P1D) {
      V[k] = 0.0;
      continue;
    }
    CeedScalar sum = 0.0;
    for (int q = 0; q < Q1D; ++q)
      sum += B[k + q*P1D] * U[q];
    V[k] = sum;
  }
}
409c532df63SYohann 
// Transposed contraction along y, one z-layer at a time for layers k < P1D.
// Each layer is published into the z-slab tile, then threads with tidy < P1D
// compute
//   V[k] = sum_{q < Q1D} B[tidy + q*P1D] * tile[q*Q1D + tidx];
// the others get V[k] = 0. Entries V[k] with k >= P1D are left untouched.
// Contains 2*P1D __syncthreads(), so every thread of the block must call it.
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  CeedScalar *tile = slice + tidz*Q1D*Q1D;
  for (int k = 0; k < P1D; ++k) {
    tile[tidx + tidy*Q1D] = U[k];
    __syncthreads();
    CeedScalar sum = 0.0;
    if (tidy < P1D) {
      for (int q = 0; q < Q1D; ++q)
        sum += B[tidy + q*P1D] * tile[tidx + q*Q1D];
    }
    V[k] = sum;
    __syncthreads();
  }
}
426c532df63SYohann 
// Transposed contraction along x, one z-layer at a time for layers k < P1D.
// Each layer is published into the z-slab tile, then threads with tidx < P1D
// compute
//   V[k] = sum_{q < Q1D} B[tidx + q*P1D] * tile[tidy*Q1D + q];
// the others get V[k] = 0. Entries V[k] with k >= P1D are left untouched.
// Contains 2*P1D __syncthreads(), so every thread of the block must call it.
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  CeedScalar *tile = slice + tidz*Q1D*Q1D;
  for (int k = 0; k < P1D; ++k) {
    tile[tidx + tidy*Q1D] = U[k];
    __syncthreads();
    CeedScalar sum = 0.0;
    if (tidx < P1D) {
      for (int q = 0; q < Q1D; ++q)
        sum += B[tidx + q*P1D] * tile[q + tidy*Q1D];
    }
    V[k] = sum;
    __syncthreads();
  }
}
443c532df63SYohann 
// 3D tensor-product interpolation. Each thread owns one (x, y) column and
// keeps its z-values in registers.
// transpose == 0: nodes (d_U) -> quadrature points (d_V); otherwise the
// transposed map quadrature points -> nodes.
// threadIdx.x/y index the column. threadIdx.z packs (element, component) as
// tidz = blockElem*BASIS_NCOMP + comp, and each tidz owns the shared tile
// slice[tidz*Q1D*Q1D ..].
// NOTE(review): assumes blockDim.z is a nonzero multiple of BASIS_NCOMP -
// confirm at the launch site.
//
// The x/y contractions contain __syncthreads(), so the loop bound (elemBase)
// is identical for every thread of the block. Threads whose element lies past
// nelem stay in the loop as inactive: they skip global memory but still join
// every barrier (working on zeros in their own tile).
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V[Q1D];
  CeedScalar r_t[Q1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elemBase = blockIdx.x*elemsPerBlock; elemBase < nelem;
       elemBase += gridDim.x*elemsPerBlock) {
    const CeedInt elem = elemBase + blockElem;
    const bool active = elem < nelem;
    for (int i = 0; i < Q1D; ++i) {
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      if (active) readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      if (active) writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
    } else {
      if (active) readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      if (active) writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
    }
  }
}
479c532df63SYohann 
// 3D tensor-product gradient. Derivative direction dim uses G along that axis
// and B along the other two.
// transpose == 0: nodes (d_U) -> quadrature derivatives (d_V, dim = 0, 1, 2);
// otherwise the transposed map, summing the three contributions into the nodes.
// Thread layout and shared tiles follow interp3d:
// tidz = blockElem*BASIS_NCOMP + comp, and each tidz owns slice[tidz*Q1D*Q1D ..].
// NOTE(review): assumes blockDim.z is a nonzero multiple of BASIS_NCOMP -
// confirm at the launch site.
//
// The x/y contractions contain __syncthreads(), so the loop bound is
// identical for every thread. Threads whose element lies past nelem stay in
// the loop as inactive: they skip global memory but still join every barrier.
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  //use P1D for one of these
  CeedScalar r_U[Q1D];
  CeedScalar r_V[Q1D];
  CeedScalar r_t[Q1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elemBase = blockIdx.x*elemsPerBlock; elemBase < nelem;
       elemBase += gridDim.x*elemsPerBlock) {
    const CeedInt elem = elemBase + blockElem;
    const bool active = elem < nelem;
    // Start each element from zeroed registers. Inactive threads then carry
    // zeros through the contractions, and the entries k >= P1D of r_V (which
    // ContractTransposeX3d does not write but add() reads) are defined.
    for (int i = 0; i < Q1D; ++i) {
      r_U[i] = 0.0;
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      if (active) readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      if (active) writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_G, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 1;
      if (active) writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 2;
      if (active) writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      dim = 0;
      if (active) readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      dim = 1;
      if (active) readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_G, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      dim = 2;
      if (active) readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      if (active) writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
538c532df63SYohann 
539c532df63SYohann /////////////
540c532df63SYohann // Kernels //
541c532df63SYohann /////////////
// Tensor-product interpolation kernel. Dispatches on the JIT constant
// BASIS_DIM to interp1d / interp2d / interp3d, handing each helper the
// dynamically sized shared-memory buffer slice (its byte size is the
// sharedMem launch argument chosen by the host apply routine).
extern "C" __global__ void interp(const CeedInt nelem, const int transpose,
                                  const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  // Declared with CeedScalar rather than a hard-coded double so the shared
  // buffer always matches the scalar type used by the rest of these kernels
  // and by the host-side sharedMem computation (sizeof(CeedScalar)).
  extern __shared__ CeedScalar slice[];
  if (BASIS_DIM==1) {
    interp1d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM==2) {
    interp2d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM==3) {
    interp3d(nelem, transpose, c_B, d_U, d_V, slice);
  }
}
554c532df63SYohann 
// Tensor-product gradient kernel. Dispatches on the JIT constant BASIS_DIM
// to grad1d / grad2d / grad3d, passing both the 1D interpolation (c_B) and
// 1D gradient (c_G) matrices plus the dynamically sized shared-memory buffer
// slice (byte size set by the sharedMem launch argument on the host).
extern "C" __global__ void grad(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *c_G,
                                const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V) {
  // CeedScalar instead of a hard-coded double keeps the shared buffer type
  // consistent with every other scalar buffer in these kernels.
  extern __shared__ CeedScalar slice[];
  if (BASIS_DIM==1) {
    grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM==2) {
    grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM==3) {
    grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  }
}
567c532df63SYohann 
568c532df63SYohann /////////////
569c532df63SYohann // Weights //
570c532df63SYohann /////////////
// Stores the 1D quadrature weight of point threadIdx.x into w for every
// element this thread visits. threadIdx.y picks an element inside the block
// and elements advance by gridDim.x*blockDim.y. The w[e*Q1D + q] indexing
// assumes blockDim.x == Q1D (the host launches Q1d threads in x).
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int q = threadIdx.x;
  const CeedScalar wq = qweight1d[q];
  CeedInt e = blockIdx.x*blockDim.y + threadIdx.y;
  while (e < nelem) {
    w[e*Q1D + q] = wq;
    e += gridDim.x*blockDim.y;
  }
}
581c532df63SYohann 
// Stores the 2D tensor-product weight qweight1d[x]*qweight1d[y] for point
// (threadIdx.x, threadIdx.y) into w for every element this thread visits.
// threadIdx.z picks an element inside the block; elements advance by
// gridDim.x*blockDim.z. Each element owns Q1D*Q1D contiguous entries of w.
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const int pointOffset = qx + qy*Q1D;
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy];
  for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < nelem;
       e += gridDim.x*blockDim.z)
    w[e*Q1D*Q1D + pointOffset] = wq;
}
593c532df63SYohann 
// Stores the 3D tensor-product weight for point
// (threadIdx.x, threadIdx.y, threadIdx.z) into w. Each block handles
// element blockIdx.x and then every gridDim.x-th element after it; each
// element owns Q1D*Q1D*Q1D contiguous entries of w.
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const int qz = threadIdx.z;
  // Multiplied left to right as (wx*wy)*wz.
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy]*qweight1d[qz];
  const int pointOffset = qx + Q1D*(qy + Q1D*qz);
  const int elemSize = Q1D*Q1D*Q1D;
  int e = blockIdx.x;
  while (e < nelem) {
    w[e*elemSize + pointOffset] = wq;
    e += gridDim.x;
  }
}
605c532df63SYohann 
// Quadrature-weight kernel: fills v with the tensor-product quadrature
// weights of every element, using the helper that matches BASIS_DIM.
// Other BASIS_DIM values leave v untouched.
extern "C" __global__ void weight(const CeedInt nelem,
                                  const CeedScalar *__restrict__ qweight1d, CeedScalar *__restrict__ v) {
  switch (BASIS_DIM) {
  case 1:
    weight1d(nelem, qweight1d, v);
    break;
  case 2:
    weight2d(nelem, qweight1d, v);
    break;
  case 3:
    weight3d(nelem, qweight1d, v);
    break;
  }
}
616c532df63SYohann 
617c532df63SYohann );
618c532df63SYohann 
619c532df63SYohann int CeedCudaInitInterp(CeedScalar *d_B, CeedInt P1d, CeedInt Q1d,
620c532df63SYohann                        CeedScalar **c_B);
621c532df63SYohann int CeedCudaInitInterpGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P1d,
622c532df63SYohann                            CeedInt Q1d, CeedScalar **c_B_ptr, CeedScalar **c_G_ptr);
623c532df63SYohann 
624c532df63SYohann int CeedBasisApplyTensor_Cuda_shared(CeedBasis basis, const CeedInt nelem,
625c532df63SYohann                                      CeedTransposeMode tmode,
626c532df63SYohann                                      CeedEvalMode emode, CeedVector u, CeedVector v) {
627c532df63SYohann   int ierr;
628c532df63SYohann   Ceed ceed;
629c532df63SYohann   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
630c532df63SYohann   Ceed_Cuda_shared *ceed_Cuda;
631c532df63SYohann   CeedGetData(ceed, (void *) &ceed_Cuda); CeedChk(ierr);
632c532df63SYohann   CeedBasis_Cuda_shared *data;
633c532df63SYohann   CeedBasisGetData(basis, (void *)&data); CeedChk(ierr);
634c532df63SYohann   const CeedInt transpose = tmode == CEED_TRANSPOSE;
6354247ecf3SYohann Dudouit   CeedInt dim, ncomp;
636074be161SYohann Dudouit   ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
6374247ecf3SYohann Dudouit   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
638c532df63SYohann 
639c532df63SYohann   const CeedScalar *d_u;
640c532df63SYohann   CeedScalar *d_v;
641c532df63SYohann   if(emode!=CEED_EVAL_WEIGHT) {
642c532df63SYohann     ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChk(ierr);
643c532df63SYohann   }
644c532df63SYohann   ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChk(ierr);
645c532df63SYohann 
646c532df63SYohann   if (tmode == CEED_TRANSPOSE) {
647c532df63SYohann     CeedInt length;
648c532df63SYohann     ierr = CeedVectorGetLength(v, &length); CeedChk(ierr);
649c532df63SYohann     ierr = cudaMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChk(ierr);
650c532df63SYohann   }
651c532df63SYohann   if (emode == CEED_EVAL_INTERP) {
652c532df63SYohann     CeedInt P1d, Q1d;
653c532df63SYohann     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
654c532df63SYohann     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
655c532df63SYohann     ierr = CeedCudaInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B);
656c532df63SYohann     CeedChk(ierr);
657c532df63SYohann     void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B, &d_u, &d_v};
6584d537eeaSYohann     if (dim==1) {
659d94769d2SYohann Dudouit       CeedInt elemsPerBlock = 32;
6604d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
6614d537eeaSYohann                                              ? 1 : 0 );
662d94769d2SYohann Dudouit       CeedInt sharedMem = elemsPerBlock*Q1d*sizeof(CeedScalar);
6634d537eeaSYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, 1,
6644d537eeaSYohann                                         elemsPerBlock, sharedMem,
665c532df63SYohann                                         interpargs);
666c532df63SYohann       CeedChk(ierr);
667074be161SYohann Dudouit     } else if (dim==2) {
6684247ecf3SYohann Dudouit       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
6694247ecf3SYohann Dudouit       CeedInt elemsPerBlock = Q1d < 7 ? optElems[Q1d]/ncomp : 1;
6704d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
6714d537eeaSYohann                                              ? 1 : 0 );
6724247ecf3SYohann Dudouit       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
6734d537eeaSYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, Q1d,
6744d537eeaSYohann                                         ncomp*elemsPerBlock, sharedMem,
675074be161SYohann Dudouit                                         interpargs);
676074be161SYohann Dudouit       CeedChk(ierr);
677074be161SYohann Dudouit     } else if (dim==3) {
6783f63d318SYohann Dudouit       CeedInt elemsPerBlock = 1;
6794d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
6804d537eeaSYohann                                              ? 1 : 0 );
681698ebc35SYohann Dudouit       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
6824d537eeaSYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, Q1d,
6834d537eeaSYohann                                         ncomp*elemsPerBlock, sharedMem,
684074be161SYohann Dudouit                                         interpargs);
685074be161SYohann Dudouit       CeedChk(ierr);
686074be161SYohann Dudouit     }
687c532df63SYohann   } else if (emode == CEED_EVAL_GRAD) {
688c532df63SYohann     CeedInt P1d, Q1d;
689c532df63SYohann     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
690c532df63SYohann     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
691c532df63SYohann     ierr = CeedCudaInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d,
692c532df63SYohann                                   Q1d, &data->c_B, &data->c_G);
693c532df63SYohann     CeedChk(ierr);
694c532df63SYohann     void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B, &data->c_G, &d_u, &d_v};
6954d537eeaSYohann     if (dim==1) {
696d94769d2SYohann Dudouit       CeedInt elemsPerBlock = 32;
6974d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
6984d537eeaSYohann                                              ? 1 : 0 );
699d94769d2SYohann Dudouit       CeedInt sharedMem = elemsPerBlock*Q1d*sizeof(CeedScalar);
7004d537eeaSYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, 1, elemsPerBlock,
7014d537eeaSYohann                                         sharedMem,
702c532df63SYohann                                         gradargs);
703c532df63SYohann       CeedChk(ierr);
704074be161SYohann Dudouit     } else if (dim==2) {
7054247ecf3SYohann Dudouit       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
7064247ecf3SYohann Dudouit       CeedInt elemsPerBlock = Q1d < 7 ? optElems[Q1d]/ncomp : 1;
7074d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
7084d537eeaSYohann                                              ? 1 : 0 );
7094247ecf3SYohann Dudouit       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
7104d537eeaSYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, Q1d,
7114d537eeaSYohann                                         ncomp*elemsPerBlock, sharedMem,
712074be161SYohann Dudouit                                         gradargs);
713074be161SYohann Dudouit       CeedChk(ierr);
714074be161SYohann Dudouit     } else if (dim==3) {
7153f63d318SYohann Dudouit       CeedInt elemsPerBlock = 1;
7164d537eeaSYohann       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
7174d537eeaSYohann                                              ? 1 : 0 );
718698ebc35SYohann Dudouit       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
7194d537eeaSYohann       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, Q1d,
7204d537eeaSYohann                                         ncomp*elemsPerBlock, sharedMem,
721074be161SYohann Dudouit                                         gradargs);
722074be161SYohann Dudouit       CeedChk(ierr);
723074be161SYohann Dudouit     }
724c532df63SYohann   } else if (emode == CEED_EVAL_WEIGHT) {
725074be161SYohann Dudouit     CeedInt Q1d;
726074be161SYohann Dudouit     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
727c532df63SYohann     void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
728074be161SYohann Dudouit     if(dim==1) {
729074be161SYohann Dudouit       const CeedInt elemsPerBlock = 32/Q1d;
7304d537eeaSYohann       const CeedInt gridsize = nelem/elemsPerBlock + ( (
7314d537eeaSYohann                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
7324d537eeaSYohann       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, elemsPerBlock, 1,
7334d537eeaSYohann                                   weightargs);
7341226057fSYohann Dudouit       CeedChk(ierr);
735074be161SYohann Dudouit     } else if(dim==2) {
736717ff8a3SYohann Dudouit       const CeedInt optElems = 32/(Q1d*Q1d);
737717ff8a3SYohann Dudouit       const CeedInt elemsPerBlock = optElems>0?optElems:1;
7384d537eeaSYohann       const CeedInt gridsize = nelem/elemsPerBlock + ( (
7394d537eeaSYohann                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
7404d537eeaSYohann       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d,
7414d537eeaSYohann                                   elemsPerBlock, weightargs);
7421226057fSYohann Dudouit       CeedChk(ierr);
743074be161SYohann Dudouit     } else if(dim==3) {
744074be161SYohann Dudouit       const CeedInt gridsize = nelem;
7454d537eeaSYohann       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d, Q1d,
7464d537eeaSYohann                                   weightargs);
7471226057fSYohann Dudouit       CeedChk(ierr);
748074be161SYohann Dudouit     }
749c532df63SYohann   }
750c532df63SYohann 
751c532df63SYohann   if(emode!=CEED_EVAL_WEIGHT) {
752c532df63SYohann     ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChk(ierr);
753c532df63SYohann   }
754c532df63SYohann   ierr = CeedVectorRestoreArray(v, &d_v); CeedChk(ierr);
755c532df63SYohann 
756c532df63SYohann   return 0;
757c532df63SYohann }
758c532df63SYohann 
759c532df63SYohann static int CeedBasisDestroy_Cuda_shared(CeedBasis basis) {
760c532df63SYohann   int ierr;
761c532df63SYohann   Ceed ceed;
762c532df63SYohann   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
763c532df63SYohann 
764c532df63SYohann   CeedBasis_Cuda_shared *data;
765c532df63SYohann   ierr = CeedBasisGetData(basis, (void *) &data); CeedChk(ierr);
766c532df63SYohann 
767c532df63SYohann   CeedChk_Cu(ceed, cuModuleUnload(data->module));
768c532df63SYohann 
769c532df63SYohann   ierr = cudaFree(data->d_qweight1d); CeedChk_Cu(ceed, ierr);
770c532df63SYohann   ierr = cudaFree(data->d_interp1d); CeedChk_Cu(ceed, ierr);
771c532df63SYohann   ierr = cudaFree(data->d_grad1d); CeedChk_Cu(ceed, ierr);
772c532df63SYohann 
773c532df63SYohann   ierr = CeedFree(&data); CeedChk(ierr);
774c532df63SYohann 
775c532df63SYohann   return 0;
776c532df63SYohann }
777c532df63SYohann 
778c532df63SYohann int CeedBasisCreateTensorH1_Cuda_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
779c532df63SYohann                                         const CeedScalar *interp1d,
780c532df63SYohann                                         const CeedScalar *grad1d,
781c532df63SYohann                                         const CeedScalar *qref1d,
782c532df63SYohann                                         const CeedScalar *qweight1d,
783c532df63SYohann                                         CeedBasis basis) {
784c532df63SYohann   int ierr;
785c532df63SYohann   Ceed ceed;
786c532df63SYohann   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
7874d537eeaSYohann   if (Q1d<P1d) {
7881226057fSYohann Dudouit     return CeedError(ceed, 1, "Backend does not implement underintegrated basis.");
7891226057fSYohann Dudouit   }
790c532df63SYohann   CeedBasis_Cuda_shared *data;
791c532df63SYohann   ierr = CeedCalloc(1, &data); CeedChk(ierr);
792c532df63SYohann 
793c532df63SYohann   const CeedInt qBytes = Q1d * sizeof(CeedScalar);
794c532df63SYohann   ierr = cudaMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Cu(ceed, ierr);
795c532df63SYohann   ierr = cudaMemcpy(data->d_qweight1d, qweight1d, qBytes,
796c532df63SYohann                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
797c532df63SYohann 
798c532df63SYohann   const CeedInt iBytes = qBytes * P1d;
799c532df63SYohann   ierr = cudaMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Cu(ceed, ierr);
800c532df63SYohann   ierr = cudaMemcpy(data->d_interp1d, interp1d, iBytes,
801c532df63SYohann                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
802c532df63SYohann 
803c532df63SYohann   ierr = cudaMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Cu(ceed, ierr);
804c532df63SYohann   ierr = cudaMemcpy(data->d_grad1d, grad1d, iBytes,
805c532df63SYohann                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
806c532df63SYohann 
807c532df63SYohann   CeedInt ncomp;
808c532df63SYohann   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
8094a6d4bbdSYohann Dudouit   ierr = CeedCompileCuda(ceed, kernelsShared, &data->module, 7,
810c532df63SYohann                          "Q1D", Q1d,
811c532df63SYohann                          "P1D", P1d,
812c532df63SYohann                          "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
813c532df63SYohann                              Q1d : P1d, dim),
814c532df63SYohann                          "BASIS_DIM", dim,
815c532df63SYohann                          "BASIS_NCOMP", ncomp,
816c532df63SYohann                          "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
817c532df63SYohann                          "BASIS_NQPT", CeedIntPow(Q1d, dim)
818c532df63SYohann                         ); CeedChk(ierr);
8194a6d4bbdSYohann Dudouit   ierr = CeedGetKernelCuda(ceed, data->module, "interp", &data->interp);
820c532df63SYohann   CeedChk(ierr);
8214a6d4bbdSYohann Dudouit   ierr = CeedGetKernelCuda(ceed, data->module, "grad", &data->grad);
822c532df63SYohann   CeedChk(ierr);
8234a6d4bbdSYohann Dudouit   ierr = CeedGetKernelCuda(ceed, data->module, "weight", &data->weight);
824c532df63SYohann   CeedChk(ierr);
825c532df63SYohann 
826c532df63SYohann   ierr = CeedBasisSetData(basis, (void *)&data);
827c532df63SYohann   CeedChk(ierr);
828c532df63SYohann   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
829c532df63SYohann                                 CeedBasisApplyTensor_Cuda_shared);
830c532df63SYohann   CeedChk(ierr);
831c532df63SYohann   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
832c532df63SYohann                                 CeedBasisDestroy_Cuda_shared);
833c532df63SYohann   CeedChk(ierr);
834c532df63SYohann   return 0;
835c532df63SYohann }
836