xref: /libCEED/rust/libceed-sys/c-src/backends/cuda-shared/ceed-cuda-shared-basis.c (revision 074be161bac8d8f2ff6efdceafa0bbdf1835071b)
1c532df63SYohann // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2c532df63SYohann // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3c532df63SYohann // All Rights reserved. See files LICENSE and NOTICE for details.
4c532df63SYohann //
5c532df63SYohann // This file is part of CEED, a collection of benchmarks, miniapps, software
6c532df63SYohann // libraries and APIs for efficient high-order finite element and spectral
7c532df63SYohann // element discretizations for exascale applications. For more information and
8c532df63SYohann // source code availability see http://github.com/ceed.
9c532df63SYohann //
10c532df63SYohann // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11c532df63SYohann // a collaborative effort of two U.S. Department of Energy organizations (Office
12c532df63SYohann // of Science and the National Nuclear Security Administration) responsible for
13c532df63SYohann // the planning and preparation of a capable exascale ecosystem, including
14c532df63SYohann // software, applications, hardware, advanced system engineering and early
15c532df63SYohann // testbed platforms, in support of the nation's exascale computing imperative.
16c532df63SYohann 
17c532df63SYohann #include <ceed-backend.h>
18c532df63SYohann #include <ceed.h>
19c532df63SYohann #include "ceed-cuda-shared.h"
20c532df63SYohann #include "../cuda/ceed-cuda.h"
21c532df63SYohann 
22c532df63SYohann //*********************
23c532df63SYohann // shared mem kernels
24c532df63SYohann static const char *kernelsShared = QUOTE(
25c532df63SYohann 
// Entrywise accumulation r_V[i] += r_U[i] for 0 <= i < Q1D.
// Both arrays must provide at least Q1D entries.
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  int idx = 0;
  while (idx < Q1D) {
    r_V[idx] += r_U[idx];
    ++idx;
  }
}
30c532df63SYohann 
31c532df63SYohann //////////
32c532df63SYohann //  1D  //
33c532df63SYohann //////////
34c532df63SYohann 
// Fill slice with the P1D values of component comp of element elem.
// Entries 0 <= i < P1D are copied from d_U starting at offset
// comp*P1D + elem*BASIS_NCOMP*P1D; entries P1D <= i < Q1D are set to zero.
// Each calling thread performs the whole copy. tidx, tidy and nelem are
// accepted for signature symmetry but are not used.
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U, CeedScalar *slice) {
  const CeedScalar *nodes = d_U + (comp*P1D + elem*BASIS_NCOMP*P1D);
  int n = 0;
  for (; n < P1D; n++)
    slice[n] = nodes[n];
  // Zero the tail so that consumers may read up to Q1D entries
  for (n = P1D; n < Q1D; n++)
    slice[n] = 0.0;
}
43c532df63SYohann 
// Store r_V as entry tidx of component comp of element elem in d_V, using
// the layout tidx + comp*P1D + elem*BASIS_NCOMP*P1D.
// Threads with tidx >= P1D store nothing. tidy and nelem are not used.
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx >= P1D)
    return;
  const int base = (elem*BASIS_NCOMP + comp)*P1D;
  d_V[base + tidx] = r_V;
}
51c532df63SYohann 
// Copy Q1D values of d_U into slice for the given element, component and
// derivative direction dim. The source layout is
//   i + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D.
// Each calling thread performs the whole copy. tidx and tidy are not used.
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem, const CeedScalar *d_U, CeedScalar *slice) {
  const int base = ((dim*BASIS_NCOMP + comp)*nelem + elem)*Q1D;
  for (int q = 0; q < Q1D; ++q) {
    slice[q] = d_U[base + q];
  }
}
58c532df63SYohann 
// Store r_V at position tidx of the given element, component and derivative
// direction dim in d_V, using the layout
//   tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D.
// No bound check is applied to tidx. tidy is not used.
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem, const CeedScalar &r_V, CeedScalar *d_V) {
  const int base = ((dim*BASIS_NCOMP + comp)*nelem + elem)*Q1D;
  d_V[base + tidx] = r_V;
}
64c532df63SYohann 
// Compute V = sum over 0 <= i < P1D of B[i + tidx*P1D] * slice[i].
// U and tidy are accepted for signature symmetry but are not used.
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy,
                                   const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  const CeedScalar *bRow = B + tidx*P1D;
  CeedScalar acc = 0.0;
  for (int i = 0; i < P1D; i++)
    acc += bRow[i] * slice[i];//contract x direction
  V = acc;
}
73c532df63SYohann 
// Transposed 1D contraction:
//   V = sum over 0 <= i < Q1D of B[tidx + i*P1D] * slice[i].
// Only threads with tidx < P1D own an output entry. Other threads get
// V = 0 and do not touch B, which mirrors the guard in ContractTransposeX2d
// and avoids indexing B at tidx + i*P1D past its last column.
// NOTE(review): assumes B holds P1D*Q1D entries - confirm against the caller.
// U and tidy are accepted for signature symmetry but are not used.
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
    const int tidy,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  V = 0.0;
  if (tidx < P1D) {
    for (int i = 0; i < Q1D; ++i) {
      V += B[tidx + i*P1D] * slice[i];//contract x direction
    }
  }
}
82c532df63SYohann 
// 1D interpolation driver.
// Elements are distributed as elem = blockIdx.x*blockDim.z + threadIdx.z,
// advancing by gridDim.x*blockDim.z. For every component:
//   transpose == 0: readDofs1d -> ContractX1d with c_B -> writeQuads1d
//   transpose != 0: readQuads1d -> ContractTransposeX1d with c_B -> writeDofs1d
// slice is the scratch buffer handed to the read and contract helpers.
// NOTE(review): the 1D helpers rewrite slice from every thread and contain no
// barrier, and slice is not offset by threadIdx.z; confirm that the launch
// configuration makes this safe.
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const CeedInt elemStride = gridDim.x*blockDim.z;

  CeedScalar r_t; // only forwarded as the unused U argument of the contractions
  CeedScalar r_V; // contraction result of this thread

  CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z;
  while (elem < nelem) {
    for (int comp = 0; comp < BASIS_NCOMP; ++comp) {
      if (transpose) {
        readQuads1d(elem, tidx, tidy, comp, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, r_t, c_B, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      } else {
        readDofs1d(elem, tidx, tidy, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, r_t, c_B, r_V);
        writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      }
    }
    elem += elemStride;
  }
}
109c532df63SYohann 
// 1D gradient driver.
// Same element distribution as interp1d: elem starts at
// blockIdx.x*blockDim.z + threadIdx.z and advances by gridDim.x*blockDim.z.
// For every component the contraction uses c_G in place of c_B:
//   transpose == 0: readDofs1d -> ContractX1d with c_G -> writeQuads1d at dim 0
//   transpose != 0: readQuads1d at dim 0 -> ContractTransposeX1d with c_G
//                   -> writeDofs1d
// c_B is accepted but not used in 1D.
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const CeedInt elemStride = gridDim.x*blockDim.z;

  CeedScalar r_U; // only forwarded as the unused U argument of the contractions
  CeedScalar r_V; // contraction result of this thread

  CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z;
  while (elem < nelem) {
    for (int comp = 0; comp < BASIS_NCOMP; ++comp) {
      if (transpose) {
        readQuads1d(elem, tidx, tidy, comp, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, r_U, c_G, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      } else {
        readDofs1d(elem, tidx, tidy, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, r_U, c_G, r_V);
        writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      }
    }
    elem += elemStride;
  }
}
138c532df63SYohann //////////
139c532df63SYohann //  2D  //
140c532df63SYohann //////////
141c532df63SYohann 
// Load into U the entry owned by thread tidx and tidy for component comp of
// element elem, using the layout
//   tidx + tidy*P1D + comp*P1D*P1D + elem*BASIS_NCOMP*P1D*P1D.
// Threads with tidx >= P1D or tidy >= P1D receive zero. nelem is not used.
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U, CeedScalar &U) {
  if (tidx < P1D && tidy < P1D) {
    const int base = (elem*BASIS_NCOMP + comp)*P1D*P1D;
    U = d_U[base + tidy*P1D + tidx];
  } else {
    U = 0.0;
  }
}
149c532df63SYohann 
// Store r_V as the entry owned by thread tidx and tidy for component comp of
// element elem in d_V, using the layout
//   tidx + tidy*P1D + comp*P1D*P1D + elem*BASIS_NCOMP*P1D*P1D.
// Threads with tidx >= P1D or tidy >= P1D store nothing. nelem is not used.
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D)
    return;
  const int base = (elem*BASIS_NCOMP + comp)*P1D*P1D;
  d_V[base + tidy*P1D + tidx] = r_V;
}
157c532df63SYohann 
// Load into U the value at position tidx and tidy for the given element,
// component and derivative direction dim, using the layout
//   tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem
//        + dim*BASIS_NCOMP*nelem*Q1D*Q1D.
// No bound check is applied to tidx or tidy.
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem, const CeedScalar *d_U, CeedScalar &U ) {
  const int base = ((dim*BASIS_NCOMP + comp)*nelem + elem)*Q1D*Q1D;
  U = d_U[base + tidy*Q1D + tidx];
}
164c532df63SYohann 
// Store r_V at position tidx and tidy for the given element, component and
// derivative direction dim in d_V, using the layout
//   tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem
//        + dim*BASIS_NCOMP*nelem*Q1D*Q1D.
// No bound check is applied to tidx or tidy.
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem, const CeedScalar &r_V, CeedScalar *d_V) {
  const int base = ((dim*BASIS_NCOMP + comp)*nelem + elem)*Q1D*Q1D;
  d_V[base + tidy*Q1D + tidx] = r_V;
}
171c532df63SYohann 
// Contract along x through the scratch buffer.
// Each thread publishes U at slice[tidx + tidy*Q1D], then computes
//   V = sum over 0 <= i < P1D of B[i + tidx*P1D] * slice[i + tidy*Q1D].
// The first barrier makes all published values visible before they are read;
// the second keeps slice intact until every thread has finished reading.
// Every thread of the block must call this function.
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy,
                                   const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx + tidy*Q1D] = U;
  __syncthreads();
  const CeedScalar *bRow = B + tidx*P1D;
  const CeedScalar *sRow = slice + tidy*Q1D;
  CeedScalar acc = 0.0;
  for (int i = 0; i < P1D; i++)
    acc += bRow[i] * sRow[i];//contract x direction
  V = acc;
  __syncthreads();
}
183c532df63SYohann 
// Contract along y through the scratch buffer.
// Each thread publishes U at slice[tidx + tidy*Q1D], then computes
//   V = sum over 0 <= i < P1D of B[i + tidy*P1D] * slice[tidx + i*Q1D].
// Barriers surround the read phase, so every thread of the block must call
// this function.
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy,
                                   const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx + tidy*Q1D] = U;
  __syncthreads();
  const CeedScalar *bRow = B + tidy*P1D;
  const CeedScalar *sCol = slice + tidx;
  CeedScalar acc = 0.0;
  for (int i = 0; i < P1D; i++)
    acc += bRow[i] * sCol[i*Q1D];//contract y direction
  V = acc;
  __syncthreads();
}
195c532df63SYohann 
// Transposed contraction along y through the scratch buffer.
// Each thread publishes U at slice[tidx + tidy*Q1D]. Threads with tidy < P1D
// then compute
//   V = sum over 0 <= i < Q1D of B[tidy + i*P1D] * slice[tidx + i*Q1D];
// all other threads get V = 0. Both barriers sit outside the tidy test so
// that every thread of the block reaches them.
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
    const int tidy,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx + tidy*Q1D] = U;
  __syncthreads();
  CeedScalar acc = 0.0;
  if (tidy < P1D) {
    const CeedScalar *sCol = slice + tidx;
    for (int i = 0; i < Q1D; i++)
      acc += B[tidy + i*P1D] * sCol[i*Q1D];//contract y direction
  }
  V = acc;
  __syncthreads();
}
209c532df63SYohann 
// Transposed contraction along x through the scratch buffer.
// Each thread publishes U at slice[tidx + tidy*Q1D]. Threads with tidx < P1D
// then compute
//   V = sum over 0 <= i < Q1D of B[tidx + i*P1D] * slice[i + tidy*Q1D];
// all other threads get V = 0. Both barriers sit outside the tidx test so
// that every thread of the block reaches them.
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
    const int tidy,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx + tidy*Q1D] = U;
  __syncthreads();
  CeedScalar acc = 0.0;
  if (tidx < P1D) {
    const CeedScalar *sRow = slice + tidy*Q1D;
    for (int i = 0; i < Q1D; i++)
      acc += B[tidx + i*P1D] * sRow[i];//contract x direction
  }
  V = acc;
  __syncthreads();
}
223c532df63SYohann 
// 2D interpolation driver.
// Elements are distributed as elem = blockIdx.x*blockDim.z + threadIdx.z,
// advancing by gridDim.x*blockDim.z. For every component:
//   transpose == 0: readDofs2d -> ContractX2d -> ContractY2d, both with c_B
//                   -> writeQuads2d at dim 0
//   transpose != 0: readQuads2d at dim 0 -> ContractTransposeY2d
//                   -> ContractTransposeX2d, both with c_B -> writeDofs2d
// NOTE(review): the contraction helpers call __syncthreads and index slice
// without any threadIdx.z offset, so the loop is only safe when all threads
// of a block run the same number of iterations and do not share slice
// entries across z; confirm against the launch configuration.
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const CeedInt elemStride = gridDim.x*blockDim.z;

  CeedScalar r_V;
  CeedScalar r_t;

  CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z;
  while (elem < nelem) {
    for (int comp = 0; comp < BASIS_NCOMP; ++comp) {
      r_V = 0.0;
      r_t = 0.0;
      if (transpose) {
        readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
        ContractTransposeY2d(slice, tidx, tidy, r_V, c_B, r_t);
        ContractTransposeX2d(slice, tidx, tidy, r_t, c_B, r_V);
        writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      } else {
        readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
        ContractX2d(slice, tidx, tidy, r_V, c_B, r_t);
        ContractY2d(slice, tidx, tidy, r_t, c_B, r_V);
        writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      }
    }
    elem += elemStride;
  }
}
253c532df63SYohann 
// 2D gradient driver.
// Elements are distributed as elem = blockIdx.x*blockDim.z + threadIdx.z,
// advancing by gridDim.x*blockDim.z. For every component:
//   transpose == 0: the values read by readDofs2d are contracted twice,
//                   with c_G in x and c_B in y for dim 0, and with c_B in x
//                   and c_G in y for dim 1; each result goes to writeQuads2d.
//   transpose != 0: the dim 0 and dim 1 inputs are pulled back with the
//                   transposed contractions, c_B in y and c_G in x for dim 0,
//                   c_G in y and c_B in x for dim 1, and their sum is stored
//                   by writeDofs2d.
// NOTE(review): the contraction helpers call __syncthreads and index slice
// without any threadIdx.z offset; confirm against the launch configuration.
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const CeedInt elemStride = gridDim.x*blockDim.z;

  CeedScalar r_U;
  CeedScalar r_V;
  CeedScalar r_t;

  CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z;
  while (elem < nelem) {
    for (int comp = 0; comp < BASIS_NCOMP; ++comp) {
      if (transpose) {
        // Contribution of direction 0 goes straight into r_V
        readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_U);
        ContractTransposeY2d(slice, tidx, tidy, r_U, c_B, r_t);
        ContractTransposeX2d(slice, tidx, tidy, r_t, c_G, r_V);
        // Contribution of direction 1 is built in r_U and then accumulated
        readQuads2d(elem, tidx, tidy, comp, 1, nelem, d_U, r_U);
        ContractTransposeY2d(slice, tidx, tidy, r_U, c_G, r_t);
        ContractTransposeX2d(slice, tidx, tidy, r_t, c_B, r_U);
        r_V += r_U;
        writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      } else {
        readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
        // Direction 0: derivative matrix in x, interpolation matrix in y
        ContractX2d(slice, tidx, tidy, r_U, c_G, r_t);
        ContractY2d(slice, tidx, tidy, r_t, c_B, r_V);
        writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
        // Direction 1: interpolation matrix in x, derivative matrix in y
        ContractX2d(slice, tidx, tidy, r_U, c_B, r_t);
        ContractY2d(slice, tidx, tidy, r_t, c_G, r_V);
        writeQuads2d(elem, tidx, tidy, comp, 1, nelem, r_V, d_V);
      }
    }
    elem += elemStride;
  }
}
294c532df63SYohann //////////
295c532df63SYohann //  3D  //
296c532df63SYohann //////////
297c532df63SYohann 
// Load the z-column owned by thread tidx and tidy for component comp of
// element elem into r_U, using the layout
//   tidx + tidy*P1D + i*P1D*P1D + comp*P1D*P1D*P1D
//        + elem*BASIS_NCOMP*P1D*P1D*P1D.
// Threads with tidx >= P1D or tidy >= P1D receive zeros, and entries
// P1D <= i < Q1D are always set to zero. nelem is not used.
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U, CeedScalar *r_U) {
  const bool owned = tidx < P1D && tidy < P1D;
  const int base = tidx + tidy*P1D + (comp + elem*BASIS_NCOMP)*P1D*P1D*P1D;
  for (int z = 0; z < P1D; ++z) {
    if (owned)
      r_U[z] = d_U[base + z*P1D*P1D];
    else
      r_U[z] = 0.0;
  }
  for (int z = P1D; z < Q1D; ++z) {
    r_U[z] = 0.0;
  }
}
308c532df63SYohann 
// Load the Q1D values of the z-column at tidx and tidy for the given element,
// component and derivative direction dim into r_U, using the layout
//   tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D
//        + comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D.
// No bound check is applied to tidx or tidy.
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem, const CeedScalar *d_U, CeedScalar *r_U) {
  const int base = tidx + tidy*Q1D
                   + ((dim*BASIS_NCOMP + comp)*nelem + elem)*Q1D*Q1D*Q1D;
  for (int z = 0; z < Q1D; ++z) {
    r_U[z] = d_U[base + z*Q1D*Q1D];
  }
}
316c532df63SYohann 
// Store the first P1D entries of r_V as the z-column owned by thread tidx
// and tidy for component comp of element elem in d_V, using the layout
//   tidx + tidy*P1D + i*P1D*P1D + comp*P1D*P1D*P1D
//        + elem*BASIS_NCOMP*P1D*P1D*P1D.
// Threads with tidx >= P1D or tidy >= P1D store nothing. nelem is not used.
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V, CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D)
    return;
  const int base = tidx + tidy*P1D + (comp + elem*BASIS_NCOMP)*P1D*P1D*P1D;
  for (int z = 0; z < P1D; ++z) {
    d_V[base + z*P1D*P1D] = r_V[z];
  }
}
326c532df63SYohann 
// Store the Q1D entries of r_V as the z-column at tidx and tidy for the given
// element, component and derivative direction dim in d_V, using the layout
//   tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D
//        + comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D.
// No bound check is applied to tidx or tidy.
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem, const CeedScalar *r_V, CeedScalar *d_V) {
  const int base = tidx + tidy*Q1D
                   + ((dim*BASIS_NCOMP + comp)*nelem + elem)*Q1D*Q1D*Q1D;
  for (int z = 0; z < Q1D; ++z) {
    d_V[base + z*Q1D*Q1D] = r_V[z];
  }
}
334c532df63SYohann 
// Contract along x, one z-layer at a time, through the scratch buffer.
// For each layer 0 <= k < P1D every thread publishes U[k] at
// slice[tidx + tidy*Q1D] and then computes
//   V[k] = sum over 0 <= i < P1D of B[i + tidx*P1D] * slice[i + tidy*Q1D].
// Entries of V at k >= P1D are left untouched. Barriers surround each read
// phase, so every thread of the block must call this function.
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy,
                                   const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  const CeedScalar *bRow = B + tidx*P1D;
  const CeedScalar *sRow = slice + tidy*Q1D;
  for (int layer = 0; layer < P1D; layer++) {
    slice[tidx + tidy*Q1D] = U[layer];
    __syncthreads();
    CeedScalar acc = 0.0;
    for (int i = 0; i < P1D; i++)
      acc += bRow[i] * sRow[i];//contract x direction
    V[layer] = acc;
    __syncthreads();
  }
}
348c532df63SYohann 
// Contract along y, one z-layer at a time, through the scratch buffer.
// For each layer 0 <= k < P1D every thread publishes U[k] at
// slice[tidx + tidy*Q1D] and then computes
//   V[k] = sum over 0 <= i < P1D of B[i + tidy*P1D] * slice[tidx + i*Q1D].
// Entries of V at k >= P1D are left untouched. Barriers surround each read
// phase, so every thread of the block must call this function.
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy,
                                   const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  const CeedScalar *bRow = B + tidy*P1D;
  const CeedScalar *sCol = slice + tidx;
  for (int layer = 0; layer < P1D; layer++) {
    slice[tidx + tidy*Q1D] = U[layer];
    __syncthreads();
    CeedScalar acc = 0.0;
    for (int i = 0; i < P1D; i++)
      acc += bRow[i] * sCol[i*Q1D];//contract y direction
    V[layer] = acc;
    __syncthreads();
  }
}
362c532df63SYohann 
// Contract along z entirely within the arrays owned by the calling thread:
//   V[k] = sum over 0 <= i < P1D of B[i + k*P1D] * U[i]   for 0 <= k < Q1D.
// slice, tidx and tidy are accepted for signature symmetry but are not used,
// and no barrier is involved.
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy,
                                   const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int q = 0; q < Q1D; q++) {
    const CeedScalar *bRow = B + q*P1D;
    V[q] = 0.0;
    for (int p = 0; p < P1D; p++)
      V[q] += bRow[p] * U[p];//contract z direction
  }
}
373c532df63SYohann 
// Transposed contraction along z within the arrays owned by the calling
// thread:
//   V[k] = sum over 0 <= i < Q1D of B[k + i*P1D] * U[i]   for 0 <= k < P1D,
//   V[k] = 0                                              for P1D <= k < Q1D.
// slice, tidx and tidy are accepted for signature symmetry but are not used,
// and no barrier is involved.
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
    const int tidy,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int p = 0; p < Q1D; p++) {
    V[p] = 0.0;
    if (p >= P1D)
      continue; // padding entry, stays zero
    for (int q = 0; q < Q1D; q++)
      V[p] += B[p + q*P1D] * U[q];//contract z direction
  }
}
386c532df63SYohann 
// Transposed contraction along y, one z-layer at a time, through the scratch
// buffer. For each layer 0 <= k < P1D every thread publishes U[k] at
// slice[tidx + tidy*Q1D]. Threads with tidy < P1D then compute
//   V[k] = sum over 0 <= i < Q1D of B[tidy + i*P1D] * slice[tidx + i*Q1D];
// all other threads get V[k] = 0. Entries of V at k >= P1D are left
// untouched. The barriers sit outside the tidy test so that every thread of
// the block reaches them.
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
    const int tidy,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  const bool owned = tidy < P1D;
  const CeedScalar *sCol = slice + tidx;
  for (int layer = 0; layer < P1D; layer++) {
    slice[tidx + tidy*Q1D] = U[layer];
    __syncthreads();
    CeedScalar acc = 0.0;
    if (owned) {
      for (int i = 0; i < Q1D; i++)
        acc += B[tidy + i*P1D] * sCol[i*Q1D];//contract y direction
    }
    V[layer] = acc;
    __syncthreads();
  }
}
402c532df63SYohann 
// Transposed contraction along x, one z-layer at a time, through the scratch
// buffer. For each layer 0 <= k < P1D every thread publishes U[k] at
// slice[tidx + tidy*Q1D]. Threads with tidx < P1D then compute
//   V[k] = sum over 0 <= i < Q1D of B[tidx + i*P1D] * slice[i + tidy*Q1D];
// all other threads get V[k] = 0. Entries of V at k >= P1D are left
// untouched. The barriers sit outside the tidx test so that every thread of
// the block reaches them.
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
    const int tidy,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  const bool owned = tidx < P1D;
  const CeedScalar *sRow = slice + tidy*Q1D;
  for (int layer = 0; layer < P1D; layer++) {
    slice[tidx + tidy*Q1D] = U[layer];
    __syncthreads();
    CeedScalar acc = 0.0;
    if (owned) {
      for (int i = 0; i < Q1D; i++)
        acc += B[tidx + i*P1D] * sRow[i];//contract x direction
    }
    V[layer] = acc;
    __syncthreads();
  }
}
418c532df63SYohann 
// 3D interpolation driver.
// Elements are distributed as elem = blockIdx.x*blockDim.z + threadIdx.z,
// advancing by gridDim.x*blockDim.z. Each thread keeps one z-column of Q1D
// values in r_V and r_t, both zeroed at the start of every component.
//   transpose == 0: readDofs3d -> ContractX3d -> ContractY3d -> ContractZ3d,
//                   all with c_B -> writeQuads3d at dim 0
//   transpose != 0: readQuads3d at dim 0 -> ContractTransposeZ3d
//                   -> ContractTransposeY3d -> ContractTransposeX3d,
//                   all with c_B -> writeDofs3d
// NOTE(review): the x and y contraction helpers call __syncthreads and index
// slice without any threadIdx.z offset; confirm against the launch
// configuration.
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const CeedInt elemStride = gridDim.x*blockDim.z;

  CeedScalar r_V[Q1D];
  CeedScalar r_t[Q1D];

  CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z;
  while (elem < nelem) {
    for (int comp = 0; comp < BASIS_NCOMP; ++comp) {
      for (int z = 0; z < Q1D; z++) {
        r_V[z] = 0.0;
        r_t[z] = 0.0;
      }
      if (transpose) {
        readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
        ContractTransposeZ3d(slice, tidx, tidy, r_V, c_B, r_t);
        ContractTransposeY3d(slice, tidx, tidy, r_t, c_B, r_V);
        ContractTransposeX3d(slice, tidx, tidy, r_V, c_B, r_t);
        writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
      } else {
        readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
        ContractX3d(slice, tidx, tidy, r_V, c_B, r_t);
        ContractY3d(slice, tidx, tidy, r_t, c_B, r_V);
        ContractZ3d(slice, tidx, tidy, r_V, c_B, r_t);
        writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
      }
    }
    elem += elemStride;
  }
}
452c532df63SYohann 
// 3D gradient driver.
// Elements are distributed as elem = blockIdx.x*blockDim.z + threadIdx.z,
// advancing by gridDim.x*blockDim.z. Each thread keeps z-columns of Q1D
// values in r_U, r_V and r_t.
//   transpose == 0: the values read by readDofs3d are contracted three times,
//                   with c_G replacing c_B in the x, y and z contraction for
//                   dim 0, 1 and 2 respectively; each result goes to
//                   writeQuads3d.
//   transpose != 0: the dim 0, 1 and 2 inputs are pulled back with the
//                   transposed contractions, again with c_G in the x, y and z
//                   contraction respectively, summed with add, and stored by
//                   writeDofs3d.
// NOTE(review): the x and y contraction helpers call __syncthreads and index
// slice without any threadIdx.z offset; confirm against the launch
// configuration.
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  //use P1D for one of these
  CeedScalar r_U[Q1D];
  CeedScalar r_V[Q1D];
  CeedScalar r_t[Q1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  int dim;

  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp=0; comp<BASIS_NCOMP; comp++) {
      if(!transpose) {
        readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
        ContractX3d(slice, tidx, tidy, r_U, c_G, r_V);
        ContractY3d(slice, tidx, tidy, r_V, c_B, r_t);
        ContractZ3d(slice, tidx, tidy, r_t, c_B, r_V);
        dim = 0;
        writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
        ContractX3d(slice, tidx, tidy, r_U, c_B, r_V);
        ContractY3d(slice, tidx, tidy, r_V, c_G, r_t);
        ContractZ3d(slice, tidx, tidy, r_t, c_B, r_V);
        dim = 1;
        writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
        ContractX3d(slice, tidx, tidy, r_U, c_B, r_V);
        ContractY3d(slice, tidx, tidy, r_V, c_B, r_t);
        ContractZ3d(slice, tidx, tidy, r_t, c_G, r_V);
        dim = 2;
        writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      } else {
        // ContractTransposeX3d only writes entries below P1D, while add
        // touches all Q1D entries. Zero the accumulator first so that add
        // never reads an indeterminate value, as interp3d already does.
        for (int i = 0; i < Q1D; ++i) {
          r_V[i] = 0.0;
        }
        dim = 0;
        readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
        ContractTransposeZ3d(slice, tidx, tidy, r_U, c_B, r_t);
        ContractTransposeY3d(slice, tidx, tidy, r_t, c_B, r_U);
        ContractTransposeX3d(slice, tidx, tidy, r_U, c_G, r_V);
        dim = 1;
        readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
        ContractTransposeZ3d(slice, tidx, tidy, r_U, c_B, r_t);
        ContractTransposeY3d(slice, tidx, tidy, r_t, c_G, r_U);
        ContractTransposeX3d(slice, tidx, tidy, r_U, c_B, r_t);
        add(r_V, r_t);
        dim = 2;
        readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
        ContractTransposeZ3d(slice, tidx, tidy, r_U, c_G, r_t);
        ContractTransposeY3d(slice, tidx, tidy, r_t, c_B, r_U);
        ContractTransposeX3d(slice, tidx, tidy, r_U, c_B, r_t);
        add(r_V, r_t);
        writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
509c532df63SYohann 
510c532df63SYohann /////////////
511c532df63SYohann // Kernels //
512c532df63SYohann /////////////
// Kernel entry point for interpolation; dispatches on the compile-time
// constant BASIS_DIM to interp1d, interp2d or interp3d.
//   nelem     number of elements to process
//   transpose zero selects the forward action, nonzero the transposed one
//   c_B       1D interpolation matrix, indexed as B[i + q*P1D] by the helpers
//   d_U, d_V  input and output arrays
// The scratch buffer is dynamically sized shared memory supplied at launch.
// It is declared as CeedScalar so that it matches the pointer type every
// helper expects.
// NOTE(review): the 2D and 3D helpers index slice as tidx + tidy*Q1D with no
// threadIdx.z offset, so the buffer needs Q1D*Q1D entries and elements that
// share a block through blockDim.z > 1 would share it too; confirm the
// launch configuration.
extern "C" __global__ void interp(const CeedInt nelem, const int transpose,
                                  const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  extern __shared__ CeedScalar slice[];
  if (BASIS_DIM==1) {
    interp1d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM==2) {
    interp2d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM==3) {
    interp3d(nelem, transpose, c_B, d_U, d_V, slice);
  }
}
526c532df63SYohann 
// Kernel entry point for the gradient; dispatches on the compile-time
// constant BASIS_DIM to grad1d, grad2d or grad3d.
//   nelem     number of elements to process
//   transpose zero selects the forward action, nonzero the transposed one
//   c_B, c_G  1D interpolation and derivative matrices, indexed as
//             B[i + q*P1D] by the helpers
//   d_U, d_V  input and output arrays
// The scratch buffer is dynamically sized shared memory supplied at launch.
// It is declared as CeedScalar so that it matches the pointer type every
// helper expects.
// NOTE(review): the 2D and 3D helpers index slice as tidx + tidy*Q1D with no
// threadIdx.z offset, so the buffer needs Q1D*Q1D entries and elements that
// share a block through blockDim.z > 1 would share it too; confirm the
// launch configuration.
extern "C" __global__ void grad(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *c_G,
                                const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V) {
  extern __shared__ CeedScalar slice[];
  if (BASIS_DIM==1) {
    grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM==2) {
    grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM==3) {
    grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  }
}
540c532df63SYohann 
541c532df63SYohann /////////////
542c532df63SYohann // Weights //
543c532df63SYohann /////////////
// Write the 1D weights for every element: w[elem*Q1D + q] = qweight1d[q]
// with q = threadIdx.x. Elements are distributed along the y dimension of the
// block: elem starts at blockIdx.x*blockDim.y + threadIdx.y and advances by
// gridDim.x*blockDim.y.
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int q = threadIdx.x;
  const CeedScalar wq = qweight1d[q];
  CeedInt elem = blockIdx.x*blockDim.y + threadIdx.y;
  while (elem < nelem) {
    w[elem*Q1D + q] = wq;
    elem += gridDim.x*blockDim.y;
  }
}
554c532df63SYohann 
// Write the 2D tensor-product weights for every element:
//   w[elem*Q1D*Q1D + i + j*Q1D] = qweight1d[i]*qweight1d[j]
// with i = threadIdx.x and j = threadIdx.y. Elements are distributed along
// the z dimension of the block: elem starts at
// blockIdx.x*blockDim.z + threadIdx.z and advances by gridDim.x*blockDim.z.
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy];
  const int local = qx + qy*Q1D;
  CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z;
  while (elem < nelem) {
    w[elem*Q1D*Q1D + local] = wq;
    elem += gridDim.x*blockDim.z;
  }
}
566c532df63SYohann 
// Fill the 3D quadrature weights: thread (x, y, z) owns quadrature point
// (x, y, z) and writes the product of the three 1D weights; blocks stride over
// the elements by gridDim.x.
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const int qz = threadIdx.z;
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy]*qweight1d[qz];
  // Offset of this quadrature point inside one element, x index fastest
  const int offset = qx + Q1D*(qy + Q1D*qz);
  for (int elem = blockIdx.x; elem < nelem; elem += gridDim.x)
    w[elem*Q1D*Q1D*Q1D + offset] = wq;
}
578c532df63SYohann 
// Kernel entry point for the quadrature weights: forwards to the routine
// matching the compile-time BASIS_DIM; any other dimension writes nothing.
extern "C" __global__ void weight(const CeedInt nelem,
                                  const CeedScalar *__restrict__ qweight1d, CeedScalar *__restrict__ v) {
  switch (BASIS_DIM) {
  case 1:
    weight1d(nelem, qweight1d, v);
    break;
  case 2:
    weight2d(nelem, qweight1d, v);
    break;
  case 3:
    weight3d(nelem, qweight1d, v);
    break;
  default:
    break;
  }
}
589c532df63SYohann 
590c532df63SYohann                                    );
591c532df63SYohann 
592c532df63SYohann int CeedCudaInitInterp(CeedScalar *d_B, CeedInt P1d, CeedInt Q1d,
593c532df63SYohann                        CeedScalar **c_B);
594c532df63SYohann int CeedCudaInitInterpGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P1d,
595c532df63SYohann                            CeedInt Q1d, CeedScalar **c_B_ptr, CeedScalar **c_G_ptr);
596c532df63SYohann 
597c532df63SYohann int CeedBasisApplyTensor_Cuda_shared(CeedBasis basis, const CeedInt nelem,
598c532df63SYohann                                      CeedTransposeMode tmode,
599c532df63SYohann                                      CeedEvalMode emode, CeedVector u, CeedVector v) {
600c532df63SYohann   int ierr;
601c532df63SYohann   Ceed ceed;
602c532df63SYohann   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
603c532df63SYohann   Ceed_Cuda_shared *ceed_Cuda;
604c532df63SYohann   CeedGetData(ceed, (void *) &ceed_Cuda); CeedChk(ierr);
605c532df63SYohann   CeedBasis_Cuda_shared *data;
606c532df63SYohann   CeedBasisGetData(basis, (void *)&data); CeedChk(ierr);
607c532df63SYohann   const CeedInt transpose = tmode == CEED_TRANSPOSE;
608*074be161SYohann Dudouit   CeedInt dim;
609*074be161SYohann Dudouit   ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
610c532df63SYohann   // const int optElems[7] = {0,32,8,3,2,1,8};
611*074be161SYohann Dudouit   CeedInt elemsPerBlock = 1;//basis->Q1d < 7 ? optElems[basis->Q1d] : 1;
612*074be161SYohann Dudouit   CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)?
613c532df63SYohann                                      1 : 0 );
614c532df63SYohann 
615c532df63SYohann   const CeedScalar *d_u;
616c532df63SYohann   CeedScalar *d_v;
617c532df63SYohann   if(emode!=CEED_EVAL_WEIGHT) {
618c532df63SYohann     ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChk(ierr);
619c532df63SYohann   }
620c532df63SYohann   ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChk(ierr);
621c532df63SYohann 
622c532df63SYohann   if (tmode == CEED_TRANSPOSE) {
623c532df63SYohann     CeedInt length;
624c532df63SYohann     ierr = CeedVectorGetLength(v, &length); CeedChk(ierr);
625c532df63SYohann     ierr = cudaMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChk(ierr);
626c532df63SYohann   }
627c532df63SYohann   if (emode == CEED_EVAL_INTERP) {
628c532df63SYohann     //TODO: check performance difference between c_B and d_B
629c532df63SYohann     CeedInt P1d, Q1d;
630c532df63SYohann     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
631c532df63SYohann     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
632*074be161SYohann Dudouit     // if (ceed_Cuda->Q1d != Q1d || ceed_Cuda->P1d != P1d)
633*074be161SYohann Dudouit     // {
634*074be161SYohann Dudouit     //   ceed_Cuda->Q1d = Q1d;
635*074be161SYohann Dudouit     //   ceed_Cuda->P1d = P1d;
636*074be161SYohann Dudouit     //   ceed_Cuda->grad = false;
637c532df63SYohann       ierr = CeedCudaInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B);
638c532df63SYohann       CeedChk(ierr);
639*074be161SYohann Dudouit     // }
640c532df63SYohann     void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B, &d_u, &d_v};
641*074be161SYohann Dudouit     // void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->d_interp1d, &d_u, &d_v};
642*074be161SYohann Dudouit     if (dim==1)
643*074be161SYohann Dudouit     {
644*074be161SYohann Dudouit       CeedInt sharedMem = Q1d*sizeof(CeedScalar);
645*074be161SYohann Dudouit       ierr = run_kernel_dim_shared(ceed, data->interp, grid, Q1d, 1, elemsPerBlock, sharedMem,
646c532df63SYohann                             interpargs);
647c532df63SYohann       CeedChk(ierr);
648*074be161SYohann Dudouit     } else if (dim==2) {
649*074be161SYohann Dudouit       CeedInt sharedMem = Q1d*Q1d*sizeof(CeedScalar);
650*074be161SYohann Dudouit       ierr = run_kernel_dim_shared(ceed, data->interp, grid, Q1d, Q1d, elemsPerBlock, sharedMem,
651*074be161SYohann Dudouit                             interpargs);
652*074be161SYohann Dudouit       CeedChk(ierr);
653*074be161SYohann Dudouit     } else if (dim==3) {
654*074be161SYohann Dudouit       CeedInt sharedMem = Q1d*Q1d*Q1d*sizeof(CeedScalar);
655*074be161SYohann Dudouit       ierr = run_kernel_dim_shared(ceed, data->interp, grid, Q1d, Q1d, elemsPerBlock, sharedMem,
656*074be161SYohann Dudouit                             interpargs);
657*074be161SYohann Dudouit       CeedChk(ierr);
658*074be161SYohann Dudouit     }
659c532df63SYohann   } else if (emode == CEED_EVAL_GRAD) {
660c532df63SYohann     CeedInt P1d, Q1d;
661c532df63SYohann     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
662c532df63SYohann     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
663*074be161SYohann Dudouit     // if (ceed_Cuda->Q1d != Q1d || ceed_Cuda->P1d != P1d || !data->grad)
664*074be161SYohann Dudouit     // {
665*074be161SYohann Dudouit     //   ceed_Cuda->Q1d = Q1d;
666*074be161SYohann Dudouit     //   ceed_Cuda->P1d = P1d;
667*074be161SYohann Dudouit     //   ceed_Cuda->grad = true;
668c532df63SYohann       ierr = CeedCudaInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d,
669c532df63SYohann                                     Q1d, &data->c_B, &data->c_G);
670c532df63SYohann       CeedChk(ierr);
671*074be161SYohann Dudouit     // }
672c532df63SYohann     void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B, &data->c_G, &d_u, &d_v};
673*074be161SYohann Dudouit     // void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->d_interp1d, &data->d_grad1d, &d_u, &d_v};
674*074be161SYohann Dudouit     if (dim==1)
675*074be161SYohann Dudouit     {
676*074be161SYohann Dudouit       CeedInt sharedMem = Q1d*sizeof(CeedScalar);
677*074be161SYohann Dudouit       ierr = run_kernel_dim_shared(ceed, data->grad, grid, Q1d, 1, elemsPerBlock, sharedMem,
678c532df63SYohann                           gradargs);
679c532df63SYohann       CeedChk(ierr);
680*074be161SYohann Dudouit     } else if (dim==2) {
681*074be161SYohann Dudouit       CeedInt sharedMem = Q1d*Q1d*sizeof(CeedScalar);
682*074be161SYohann Dudouit       ierr = run_kernel_dim_shared(ceed, data->grad, grid, Q1d, Q1d, elemsPerBlock, sharedMem,
683*074be161SYohann Dudouit                           gradargs);
684*074be161SYohann Dudouit       CeedChk(ierr);
685*074be161SYohann Dudouit     } else if (dim==3) {
686*074be161SYohann Dudouit       CeedInt sharedMem = Q1d*Q1d*Q1d*sizeof(CeedScalar);
687*074be161SYohann Dudouit       ierr = run_kernel_dim_shared(ceed, data->grad, grid, Q1d, Q1d, elemsPerBlock, sharedMem,
688*074be161SYohann Dudouit                           gradargs);
689*074be161SYohann Dudouit       CeedChk(ierr);
690*074be161SYohann Dudouit     }
691c532df63SYohann   } else if (emode == CEED_EVAL_WEIGHT) {
692*074be161SYohann Dudouit     CeedInt Q1d;
693*074be161SYohann Dudouit     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
694c532df63SYohann     void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
695*074be161SYohann Dudouit     if(dim==1){
696*074be161SYohann Dudouit       const CeedInt elemsPerBlock = 32/Q1d;
697*074be161SYohann Dudouit       const CeedInt gridsize = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
698*074be161SYohann Dudouit       ierr = run_kernel_dim(ceed, data->weight, gridsize, Q1d, elemsPerBlock, 1, weightargs);
699*074be161SYohann Dudouit     } else if(dim==2) {
700*074be161SYohann Dudouit       const CeedInt elemsPerBlock = 32/(Q1d*Q1d);
701*074be161SYohann Dudouit       const CeedInt gridsize = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
702*074be161SYohann Dudouit       ierr = run_kernel_dim(ceed, data->weight, gridsize, Q1d, Q1d, elemsPerBlock, weightargs);
703*074be161SYohann Dudouit     } else if(dim==3) {
704*074be161SYohann Dudouit       const CeedInt gridsize = nelem;
705*074be161SYohann Dudouit       ierr = run_kernel_dim(ceed, data->weight, gridsize, Q1d, Q1d, Q1d, weightargs);
706*074be161SYohann Dudouit     }
707c532df63SYohann   }
708c532df63SYohann 
709c532df63SYohann   if(emode!=CEED_EVAL_WEIGHT) {
710c532df63SYohann     ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChk(ierr);
711c532df63SYohann   }
712c532df63SYohann   ierr = CeedVectorRestoreArray(v, &d_v); CeedChk(ierr);
713c532df63SYohann 
714c532df63SYohann   return 0;
715c532df63SYohann }
716c532df63SYohann 
717c532df63SYohann static int CeedBasisDestroy_Cuda_shared(CeedBasis basis) {
718c532df63SYohann   int ierr;
719c532df63SYohann   Ceed ceed;
720c532df63SYohann   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
721c532df63SYohann 
722c532df63SYohann   CeedBasis_Cuda_shared *data;
723c532df63SYohann   ierr = CeedBasisGetData(basis, (void *) &data); CeedChk(ierr);
724c532df63SYohann 
725c532df63SYohann   CeedChk_Cu(ceed, cuModuleUnload(data->module));
726c532df63SYohann 
727c532df63SYohann   ierr = cudaFree(data->d_qweight1d); CeedChk_Cu(ceed, ierr);
728c532df63SYohann   ierr = cudaFree(data->d_interp1d); CeedChk_Cu(ceed, ierr);
729c532df63SYohann   ierr = cudaFree(data->d_grad1d); CeedChk_Cu(ceed, ierr);
730c532df63SYohann 
731c532df63SYohann   ierr = CeedFree(&data); CeedChk(ierr);
732c532df63SYohann 
733c532df63SYohann   return 0;
734c532df63SYohann }
735c532df63SYohann 
736c532df63SYohann int CeedBasisCreateTensorH1_Cuda_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
737c532df63SYohann                                         const CeedScalar *interp1d,
738c532df63SYohann                                         const CeedScalar *grad1d,
739c532df63SYohann                                         const CeedScalar *qref1d,
740c532df63SYohann                                         const CeedScalar *qweight1d,
741c532df63SYohann                                         CeedBasis basis) {
742c532df63SYohann   int ierr;
743c532df63SYohann   Ceed ceed;
744c532df63SYohann   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
745c532df63SYohann   CeedBasis_Cuda_shared *data;
746c532df63SYohann   ierr = CeedCalloc(1, &data); CeedChk(ierr);
747c532df63SYohann 
748c532df63SYohann   const CeedInt qBytes = Q1d * sizeof(CeedScalar);
749c532df63SYohann   ierr = cudaMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Cu(ceed, ierr);
750c532df63SYohann   ierr = cudaMemcpy(data->d_qweight1d, qweight1d, qBytes,
751c532df63SYohann                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
752c532df63SYohann 
753c532df63SYohann   const CeedInt iBytes = qBytes * P1d;
754c532df63SYohann   ierr = cudaMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Cu(ceed, ierr);
755c532df63SYohann   ierr = cudaMemcpy(data->d_interp1d, interp1d, iBytes,
756c532df63SYohann                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
757c532df63SYohann 
758c532df63SYohann   ierr = cudaMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Cu(ceed, ierr);
759c532df63SYohann   ierr = cudaMemcpy(data->d_grad1d, grad1d, iBytes,
760c532df63SYohann                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
761c532df63SYohann 
762c532df63SYohann   CeedInt ncomp;
763c532df63SYohann   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
764c532df63SYohann   ierr = compile(ceed, kernelsShared, &data->module, 7,
765c532df63SYohann                  "Q1D", Q1d,
766c532df63SYohann                  "P1D", P1d,
767c532df63SYohann                  "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
768c532df63SYohann                      Q1d : P1d, dim),
769c532df63SYohann                  "BASIS_DIM", dim,
770c532df63SYohann                  "BASIS_NCOMP", ncomp,
771c532df63SYohann                  "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
772c532df63SYohann                  "BASIS_NQPT", CeedIntPow(Q1d, dim)
773c532df63SYohann                 ); CeedChk(ierr);
774c532df63SYohann   ierr = get_kernel(ceed, data->module, "interp", &data->interp);
775c532df63SYohann   CeedChk(ierr);
776c532df63SYohann   ierr = get_kernel(ceed, data->module, "grad", &data->grad);
777c532df63SYohann   CeedChk(ierr);
778c532df63SYohann   ierr = get_kernel(ceed, data->module, "weight", &data->weight);
779c532df63SYohann   CeedChk(ierr);
780c532df63SYohann 
781c532df63SYohann   ierr = CeedBasisSetData(basis, (void *)&data);
782c532df63SYohann   CeedChk(ierr);
783c532df63SYohann   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
784c532df63SYohann                                 CeedBasisApplyTensor_Cuda_shared);
785c532df63SYohann   CeedChk(ierr);
786c532df63SYohann   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
787c532df63SYohann                                 CeedBasisDestroy_Cuda_shared);
788c532df63SYohann   CeedChk(ierr);
789c532df63SYohann   return 0;
790c532df63SYohann }
791c532df63SYohann 
792c532df63SYohann int CeedBasisCreateH1_Cuda_shared(CeedElemTopology topo, CeedInt dim,
793c532df63SYohann                                   CeedInt ndof, CeedInt nqpts,
794c532df63SYohann                                   const CeedScalar *interp,
795c532df63SYohann                                   const CeedScalar *grad,
796c532df63SYohann                                   const CeedScalar *qref,
797c532df63SYohann                                   const CeedScalar *qweight,
798c532df63SYohann                                   CeedBasis basis) {
799c532df63SYohann   int ierr;
800c532df63SYohann   Ceed ceed;
801c532df63SYohann   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
802c532df63SYohann   return CeedError(ceed, 1, "Backend does not implement generic H1 basis");
803c532df63SYohann }
804