xref: /libCEED/backends/cuda-shared/ceed-cuda-shared-basis.c (revision 717ff8a3859920e3e4f3be4fafa268d53617d237)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <ceed-backend.h>
18 #include <ceed.h>
19 #include "ceed-cuda-shared.h"
20 #include "../cuda/ceed-cuda.h"
21 
22 //*********************
23 // shared mem kernels
24 static const char *kernelsShared = QUOTE(
25 
// Accumulate r_U into r_V entrywise over one Q1D-length register column.
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  for (int q = 0; q < Q1D; ++q)
    r_V[q] += r_U[q];
}
30 
31 //////////
32 //  1D  //
33 //////////
34 
// Load the P1D nodal values of (elem, comp) into the shared slice and
// zero-pad up to Q1D so the contraction can read a full Q1D-length row.
// tidx/tidy are unused (kept for a uniform read/write helper signature).
// NOTE(review): every thread of the block stores the same values to the
// whole slice (indices depend only on i/comp/elem), which is benign while
// the host launches with elemsPerBlock == 1; with blockDim.z > 1 distinct
// elements would race on the single slice buffer -- confirm before
// raising elemsPerBlock.
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U, CeedScalar *slice) {
  for (int i = 0; i < P1D; i++)
    slice[i] = d_U[i + comp*P1D + elem*BASIS_NCOMP*P1D];
  for (int i = P1D; i < Q1D; i++)
    slice[i] = 0.0;
}
43 
// Store one nodal value per thread for (elem, comp).  Threads with
// tidx >= P1D (present whenever Q1D > P1D) write nothing; tidy is unused.
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx >= P1D) return;
  d_V[tidx + comp*P1D + elem*BASIS_NCOMP*P1D] = r_V;
}
51 
// Gather the Q1D quadrature values of (elem, comp, dim) into the shared
// slice; every thread performs the same redundant stores (see the note on
// readDofs1d).  tidx/tidy are unused.
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem, const CeedScalar *d_U, CeedScalar *slice) {
  const int base = elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  for (int q = 0; q < Q1D; ++q)
    slice[q] = d_U[base + q];
}
58 
// Store one quadrature value per thread for (elem, comp, dim); tidy is
// unused.
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem, const CeedScalar &r_V, CeedScalar *d_V) {
  const int ind = tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  d_V[ind] = r_V;
}
64 
// 1D interpolation contraction: V(tidx) = sum_i B[i + tidx*P1D]*slice[i].
// U and tidy are unused; they exist so all Contract* helpers share one
// signature.
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy,
                                   const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  V = 0.0;
  for (int i = 0; i < P1D; ++i) {
    V += B[i + tidx*P1D] * slice[i];//contract x direction
  }
}
73 
// Transposed 1D contraction: V(tidx) = sum_i B[tidx + i*P1D]*slice[i].
// U and tidy are unused.
// NOTE(review): for threads with tidx >= P1D (when Q1D > P1D) the index
// tidx + i*P1D can run past the P1D*Q1D entries of B; the result is
// discarded by writeDofs1d, but the loads themselves look out of bounds
// -- confirm.
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
    const int tidy,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  V = 0.0;
  for (int i = 0; i < Q1D; ++i) {
    V += B[tidx + i*P1D] * slice[i];//contract x direction
  }
}
82 
// 1D interp: dofs -> quads (transpose == 0) or quads -> dofs
// (transpose == 1) for every component of each element mapped to this
// block.  Launch: blockDim.x = Q1D, blockDim.z = elements per block
// (currently 1); slice is dynamic shared scratch of Q1D scalars.
// r_t is only a placeholder: the 1D contractions ignore their U argument.
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;


  // Grid-stride loop over elements (one element per block z-slice).
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp=0; comp<BASIS_NCOMP; comp++) {
      if(!transpose) {
        // Nodal values -> shared slice, contract with B, write quads.
        readDofs1d(elem, tidx, tidy, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, r_t, c_B, r_V);
        writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      } else {
        // Quad values -> shared slice, contract with B^T, write dofs.
        readQuads1d(elem, tidx, tidy, comp, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, r_t, c_B, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
109 
// 1D gradient: same structure as interp1d but contracts with the
// derivative matrix c_G (only the x direction exists, so dim is always
// 0 and c_B is not used here).  r_U is only a placeholder: the 1D
// contractions ignore their U argument.
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;// unused: blockDim.y == 1 in the 1D launch
  int dim;

  // Grid-stride loop over elements (one element per block z-slice).
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp=0; comp<BASIS_NCOMP; comp++) {
      if(!transpose) {
        readDofs1d(elem, tidx, tidy, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, r_U, c_G, r_V);
        dim = 0;
        writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      } else {
        dim = 0;
        readQuads1d(elem, tidx, tidy, comp, dim, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, r_U, c_G, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
138 //////////
139 //  2D  //
140 //////////
141 
// Load one nodal value per thread for (elem, comp).  Threads outside the
// P1D x P1D dof patch (present when Q1D > P1D) receive 0.0 so they
// contribute nothing to later contractions.
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U, CeedScalar &U) {
  const bool active = tidx < P1D && tidy < P1D;
  U = active ? d_U[tidx + tidy*P1D + comp*P1D*P1D + elem*BASIS_NCOMP*P1D*P1D ] : 0.0;
}
149 
// Store one nodal value per thread for (elem, comp); threads outside the
// P1D x P1D dof patch write nothing.
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D) return;
  d_V[tidx + tidy*P1D + comp*P1D*P1D + elem*BASIS_NCOMP*P1D*P1D ] = r_V;
}
157 
// Load one quadrature value per thread for (elem, comp, dim).
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem, const CeedScalar *d_U, CeedScalar &U ) {
  const int ind = tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                  dim*BASIS_NCOMP*nelem*Q1D*Q1D;
  U = d_U[ind];
}
164 
// Store one quadrature value per thread for (elem, comp, dim).
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem, const CeedScalar &r_V, CeedScalar *d_V) {
  const int ind = tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                  dim*BASIS_NCOMP*nelem*Q1D*Q1D;
  d_V[ind] = r_V;
}
171 
// 2D x contraction: stage U in the shared slice, barrier, then each
// thread computes V(tidx,tidy) = sum_i B[i + tidx*P1D]*slice(i,tidy).
// The trailing barrier keeps slice from being overwritten by the caller's
// next stage while some threads still read it.  Contains __syncthreads(),
// so the whole block must reach this call together.
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy,
                                   const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx+tidy*Q1D] = U;
  __syncthreads();
  V = 0.0;
  for (int i = 0; i < P1D; ++i) {
    V += B[i + tidx*P1D] * slice[i + tidy*Q1D];//contract x direction
  }
  __syncthreads();
}
183 
// 2D y contraction: stage U in the shared slice, barrier, then each
// thread computes V(tidx,tidy) = sum_i B[i + tidy*P1D]*slice(tidx,i).
// Trailing barrier protects slice reuse.  Contains __syncthreads(), so
// the whole block must reach this call together.
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy,
                                   const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx+tidy*Q1D] = U;
  __syncthreads();
  V = 0.0;
  for (int i = 0; i < P1D; ++i) {
    V += B[i + tidy*P1D] * slice[tidx + i*Q1D];//contract y direction
  }
  __syncthreads();
}
195 
// Transposed 2D y contraction: threads with tidy >= P1D keep V = 0 and
// skip the accumulation; both barriers sit outside the divergent guard,
// so every thread of the block reaches them.
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
    const int tidy,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx+tidy*Q1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidy<P1D) {
    for (int i = 0; i < Q1D; ++i) {
      V += B[tidy + i*P1D] * slice[tidx + i*Q1D];//contract y direction
    }
  }
  __syncthreads();
}
209 
// Transposed 2D x contraction: threads with tidx >= P1D keep V = 0 and
// skip the accumulation; both barriers sit outside the divergent guard,
// so every thread of the block reaches them.
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
    const int tidy,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx+tidy*Q1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidx<P1D) {
    for (int i = 0; i < Q1D; ++i) {
      V += B[tidx + i*P1D] * slice[i + tidy*Q1D];//contract x direction
    }
  }
  __syncthreads();
}
223 
// 2D interp for all components of the block's elements.
// transpose == 0: dofs -> quads via B in x then y; transpose == 1:
// quads -> dofs via B^T in y then x.  Launch: blockDim = (Q1D, Q1D,
// elemsPerBlock); slice is Q1D*Q1D shared scalars shared by all z-slices
// (host currently launches with elemsPerBlock == 1).
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;

  // Grid-stride loop over elements (z dimension of the block).
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp=0; comp<BASIS_NCOMP; comp++) {
      r_V = 0.0;
      r_t = 0.0;
      if(!transpose) {
        readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
        ContractX2d(slice, tidx, tidy, r_V, c_B, r_t);
        ContractY2d(slice, tidx, tidy, r_t, c_B, r_V);
        writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      } else {
        readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
        ContractTransposeY2d(slice, tidx, tidy, r_V, c_B, r_t);
        ContractTransposeX2d(slice, tidx, tidy, r_t, c_B, r_V);
        writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
253 
// 2D gradient.  Forward (transpose == 0): write the d/dx block (G in x,
// B in y) as dim 0 and the d/dy block (B in x, G in y) as dim 1.
// Transpose: apply the transposed chain for each dim and sum the two
// contributions (r_V += r_U) before a single dof write.
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  int dim;

  // Grid-stride loop over elements (z dimension of the block).
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp=0; comp<BASIS_NCOMP; comp++) {
      if(!transpose) {
        readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
        // d/dx: derivative in x, interpolation in y.
        ContractX2d(slice, tidx, tidy, r_U, c_G, r_t);
        ContractY2d(slice, tidx, tidy, r_t, c_B, r_V);
        dim = 0;
        writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
        // d/dy: interpolation in x, derivative in y.
        ContractX2d(slice, tidx, tidy, r_U, c_B, r_t);
        ContractY2d(slice, tidx, tidy, r_t, c_G, r_V);
        dim = 1;
        writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      } else {
        dim = 0;
        readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
        ContractTransposeY2d(slice, tidx, tidy, r_U, c_B, r_t);
        ContractTransposeX2d(slice, tidx, tidy, r_t, c_G, r_V);
        dim = 1;
        readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
        ContractTransposeY2d(slice, tidx, tidy, r_U, c_G, r_t);
        ContractTransposeX2d(slice, tidx, tidy, r_t, c_B, r_U);
        // Sum the x- and y-derivative contributions before writing dofs.
        r_V+=r_U;
        writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
294 //////////
295 //  3D  //
296 //////////
297 
// Load the z-column of nodal values at (tidx, tidy) for (elem, comp)
// into registers, zero-padded up to Q1D.  Threads outside the P1D x P1D
// dof patch receive all zeros.
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U, CeedScalar *r_U) {
  const bool active = tidx < P1D && tidy < P1D;
  for (int k = 0; k < P1D; ++k)
    r_U[k] = active ? d_U[tidx + tidy*P1D + k*P1D*P1D + comp*P1D*P1D*P1D +
                          elem*BASIS_NCOMP*P1D*P1D*P1D ] : 0.0;
  for (int k = P1D; k < Q1D; ++k)
    r_U[k] = 0.0;
}
308 
// Load the Q1D-long z-column of quadrature values at (tidx, tidy) for
// (elem, comp, dim) into registers.
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem, const CeedScalar *d_U, CeedScalar *r_U) {
  const int base = tidx + tidy*Q1D + elem*Q1D*Q1D*Q1D +
                   comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
  for (int k = 0; k < Q1D; ++k)
    r_U[k] = d_U[base + k*Q1D*Q1D];
}
316 
// Store the P1D-long z-column of nodal values at (tidx, tidy) for
// (elem, comp); threads outside the P1D x P1D dof patch write nothing.
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V, CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D) return;
  const int base = tidx + tidy*P1D + comp*P1D*P1D*P1D +
                   elem*BASIS_NCOMP*P1D*P1D*P1D;
  for (int k = 0; k < P1D; ++k)
    d_V[base + k*P1D*P1D] = r_V[k];
}
326 
// Store the Q1D-long z-column of quadrature values at (tidx, tidy) for
// (elem, comp, dim).
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem, const CeedScalar *r_V, CeedScalar *d_V) {
  const int base = tidx + tidy*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
                   dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
  for (int k = 0; k < Q1D; ++k)
    d_V[base + k*Q1D*Q1D] = r_V[k];
}
334 
// 3D x contraction, one z-plane (k) at a time: stage U[k] in the shared
// slice, barrier, contract along x into V[k], barrier again before the
// slice is reused for the next plane.  Only planes k < P1D of V are
// written; later pipeline stages read only those entries.  Contains
// __syncthreads(), so the whole block must reach this call together.
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy,
                                   const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx+tidy*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    for (int i = 0; i < P1D; ++i) {
      V[k] += B[i + tidx*P1D] * slice[i + tidy*Q1D];//contract x direction
    }
    __syncthreads();
  }
}
348 
// 3D y contraction, one z-plane (k) at a time: stage U[k] in the shared
// slice, barrier, contract along y into V[k], barrier before slice reuse.
// Only planes k < P1D of V are written.  Contains __syncthreads(), so the
// whole block must reach this call together.
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy,
                                   const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx+tidy*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    for (int i = 0; i < P1D; ++i) {
      V[k] += B[i + tidy*P1D] * slice[tidx + i*Q1D];//contract y direction
    }
    __syncthreads();
  }
}
362 
// 3D z contraction: V[q] = sum_p B[p + q*P1D]*U[p].  Operates entirely in
// registers -- no shared memory, no barriers.  slice/tidx/tidy are unused
// and kept only for signature uniformity with the other contractions.
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy,
                                   const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int q = 0; q < Q1D; ++q) {
    CeedScalar sum = 0.0;
    for (int p = 0; p < P1D; ++p)
      sum += B[p + q*P1D] * U[p];//contract z direction
    V[q] = sum;
  }
}
373 
// Transposed 3D z contraction: V[p] = sum_q B[p + q*P1D]*U[q] for
// p < P1D, and V[p] = 0 for the padding rows p >= P1D.  Registers only;
// slice/tidx/tidy are unused and kept for signature uniformity.
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
    const int tidy,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int p = 0; p < Q1D; ++p) {
    CeedScalar sum = 0.0;
    if (p < P1D) {
      for (int q = 0; q < Q1D; ++q)
        sum += B[p + q*P1D] * U[q];//contract z direction
    }
    V[p] = sum;
  }
}
386 
// Transposed 3D y contraction, one plane (k < P1D) at a time: stage U[k],
// barrier, accumulate along y for threads with tidy < P1D (others keep
// V[k] = 0), barrier before slice reuse.  Barriers sit outside the
// divergent guard, so every thread of the block reaches them.
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
    const int tidy,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx+tidy*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidy<P1D) {
      for (int i = 0; i < Q1D; ++i) {
        V[k] += B[tidy + i*P1D] * slice[tidx + i*Q1D];//contract y direction
      }
    }
    __syncthreads();
  }
}
402 
// Transposed 3D x contraction, one plane (k < P1D) at a time: stage U[k],
// barrier, accumulate along x for threads with tidx < P1D (others keep
// V[k] = 0), barrier before slice reuse.  Barriers sit outside the
// divergent guard, so every thread of the block reaches them.
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
    const int tidy,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx+tidy*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx<P1D) {
      for (int i = 0; i < Q1D; ++i) {
        V[k] += B[tidx + i*P1D] * slice[i + tidy*Q1D];//contract x direction
      }
    }
    __syncthreads();
  }
}
418 
// 3D interp: dofs -> quads (transpose == 0) applies B in x, y, z;
// quads -> dofs (transpose == 1) applies B^T in z, y, x.  Each thread
// holds a full z-column in the r_V / r_t register arrays.  Launch:
// blockDim = (Q1D, Q1D, elemsPerBlock); only Q1D*Q1D entries of slice are
// indexed by the 3D contractions.
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V[Q1D];
  CeedScalar r_t[Q1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;

  // Grid-stride loop over elements (z dimension of the block).
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp=0; comp<BASIS_NCOMP; comp++) {
      // Clear both register columns; contractions only write k < P1D in
      // some stages, so the padding entries must start at zero.
      for (int i = 0; i < Q1D; ++i) {
        r_V[i] = 0.0;
        r_t[i] = 0.0;
      }
      if(!transpose) {
        readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
        ContractX3d(slice, tidx, tidy, r_V, c_B, r_t);
        ContractY3d(slice, tidx, tidy, r_t, c_B, r_V);
        ContractZ3d(slice, tidx, tidy, r_V, c_B, r_t);
        writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
      } else {
        readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
        ContractTransposeZ3d(slice, tidx, tidy, r_V, c_B, r_t);
        ContractTransposeY3d(slice, tidx, tidy, r_t, c_B, r_V);
        ContractTransposeX3d(slice, tidx, tidy, r_V, c_B, r_t);
        writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
      }
    }
  }
}
452 
// 3D gradient.  Forward (transpose == 0): for each dim d in {0,1,2},
// apply G in direction d and B in the other two, writing one quad block
// per dim.  Transpose: apply the transposed chain per dim and accumulate
// the three contributions into r_V via add() before a single dof write.
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  //use P1D for one of these
  CeedScalar r_U[Q1D];
  CeedScalar r_V[Q1D];
  CeedScalar r_t[Q1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  int dim;

  // Grid-stride loop over elements (z dimension of the block).
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp=0; comp<BASIS_NCOMP; comp++) {
      if(!transpose) {
        readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
        // d/dx: G in x, B in y and z.
        ContractX3d(slice, tidx, tidy, r_U, c_G, r_V);
        ContractY3d(slice, tidx, tidy, r_V, c_B, r_t);
        ContractZ3d(slice, tidx, tidy, r_t, c_B, r_V);
        dim = 0;
        writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
        // d/dy: G in y, B in x and z.
        ContractX3d(slice, tidx, tidy, r_U, c_B, r_V);
        ContractY3d(slice, tidx, tidy, r_V, c_G, r_t);
        ContractZ3d(slice, tidx, tidy, r_t, c_B, r_V);
        dim = 1;
        writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
        // d/dz: G in z, B in x and y.
        ContractX3d(slice, tidx, tidy, r_U, c_B, r_V);
        ContractY3d(slice, tidx, tidy, r_V, c_B, r_t);
        ContractZ3d(slice, tidx, tidy, r_t, c_G, r_V);
        dim = 2;
        writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      } else {
        dim = 0;
        readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
        ContractTransposeZ3d(slice, tidx, tidy, r_U, c_B, r_t);
        ContractTransposeY3d(slice, tidx, tidy, r_t, c_B, r_U);
        ContractTransposeX3d(slice, tidx, tidy, r_U, c_G, r_V);
        dim = 1;
        readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
        ContractTransposeZ3d(slice, tidx, tidy, r_U, c_B, r_t);
        ContractTransposeY3d(slice, tidx, tidy, r_t, c_G, r_U);
        ContractTransposeX3d(slice, tidx, tidy, r_U, c_B, r_t);
        add(r_V, r_t);
        dim = 2;
        readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
        ContractTransposeZ3d(slice, tidx, tidy, r_U, c_G, r_t);
        ContractTransposeY3d(slice, tidx, tidy, r_t, c_B, r_U);
        ContractTransposeX3d(slice, tidx, tidy, r_U, c_B, r_t);
        add(r_V, r_t);
        writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
509 
510 /////////////
511 // Kernels //
512 /////////////
// Kernel entry point for basis interpolation; BASIS_DIM is a macro fixed
// at JIT time, so only one branch survives compilation.  The host passes
// the dynamic shared-memory size (Q1D^dim scalars).
// (Fix me if ElemPerBlock>1.)
extern "C" __global__ void interp(const CeedInt nelem, const int transpose,
                                  const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  extern __shared__ double slice[];
  switch (BASIS_DIM) {
  case 1: interp1d(nelem, transpose, c_B, d_U, d_V, slice); break;
  case 2: interp2d(nelem, transpose, c_B, d_U, d_V, slice); break;
  case 3: interp3d(nelem, transpose, c_B, d_U, d_V, slice); break;
  }
}
526 
// Kernel entry point for basis gradients; BASIS_DIM is a macro fixed at
// JIT time, so only one branch survives compilation.  The host passes the
// dynamic shared-memory size (Q1D^dim scalars).
// (Fix me if ElemPerBlock>1.)
extern "C" __global__ void grad(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *c_G,
                                const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V) {
  extern __shared__ double slice[];
  switch (BASIS_DIM) {
  case 1: grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice); break;
  case 2: grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice); break;
  case 3: grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice); break;
  }
}
540 
541 /////////////
542 // Weights //
543 /////////////
// Write the 1D quadrature weights for every element.
// Launch: blockDim.x = Q1D, blockDim.y = elements per block; grid-stride
// loop over elements in the y dimension.
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int q = threadIdx.x;
  const CeedScalar wq = qweight1d[q];
  const CeedInt stride = gridDim.x*blockDim.y;
  for (CeedInt e = blockIdx.x*blockDim.y + threadIdx.y; e < nelem; e += stride)
    w[e*Q1D + q] = wq;
}
554 
// Write the tensor-product 2D quadrature weights for every element.
// Launch: blockDim = (Q1D, Q1D, elements per block); grid-stride loop
// over elements in the z dimension.
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy];
  const CeedInt stride = gridDim.x*blockDim.z;
  for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < nelem; e += stride)
    w[e*Q1D*Q1D + qx + qy*Q1D] = wq;
}
566 
// Write the tensor-product 3D quadrature weights for every element.
// Launch: blockDim = (Q1D, Q1D, Q1D); one element per block, grid-stride
// over elements.
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const int qz = threadIdx.z;
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy]*qweight1d[qz];
  for (int e = blockIdx.x; e < nelem; e += gridDim.x)
    w[e*Q1D*Q1D*Q1D + qx + qy*Q1D + qz*Q1D*Q1D] = wq;
}
578 
// Kernel entry point for quadrature-weight evaluation; BASIS_DIM is a
// macro fixed at JIT time, so only one branch survives compilation.
extern "C" __global__ void weight(const CeedInt nelem,
                                  const CeedScalar *__restrict__ qweight1d, CeedScalar *__restrict__ v) {
  switch (BASIS_DIM) {
  case 1: weight1d(nelem, qweight1d, v); break;
  case 2: weight2d(nelem, qweight1d, v); break;
  case 3: weight3d(nelem, qweight1d, v); break;
  }
}
589 
590                                    );
591 
592 int CeedCudaInitInterp(CeedScalar *d_B, CeedInt P1d, CeedInt Q1d,
593                        CeedScalar **c_B);
594 int CeedCudaInitInterpGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P1d,
595                            CeedInt Q1d, CeedScalar **c_B_ptr, CeedScalar **c_G_ptr);
596 
// Apply a tensor-product basis operation on the device.
//
// emode selects the kernel: CEED_EVAL_INTERP (interp), CEED_EVAL_GRAD
// (grad), or CEED_EVAL_WEIGHT (quadrature weights; u is not read).
// tmode == CEED_TRANSPOSE runs the quads->dofs direction and zeroes v
// first because elements accumulate into shared dofs.
// Returns 0 on success, a CEED error code otherwise.
//
// Fixes vs. previous revision: capture and check the return values of
// CeedGetData/CeedBasisGetData (previously a stale ierr was checked);
// check cudaMemset with CeedChk_Cu (it returns a CUDA status, not a CEED
// status); guard the dim==1 weight launch against Q1d > 32 (which made
// elemsPerBlock zero); check the weight-kernel launch results.
int CeedBasisApplyTensor_Cuda_shared(CeedBasis basis, const CeedInt nelem,
                                     CeedTransposeMode tmode,
                                     CeedEvalMode emode, CeedVector u, CeedVector v) {
  int ierr;
  Ceed ceed;
  ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
  Ceed_Cuda_shared *ceed_Cuda;
  ierr = CeedGetData(ceed, (void *) &ceed_Cuda); CeedChk(ierr);
  CeedBasis_Cuda_shared *data;
  ierr = CeedBasisGetData(basis, (void *)&data); CeedChk(ierr);
  const CeedInt transpose = tmode == CEED_TRANSPOSE;
  CeedInt dim;
  ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
  // The shared-memory kernels assume one element per block z-slice; keep
  // elemsPerBlock at 1 until the slice buffer handles several elements.
  // const int optElems[7] = {0,32,8,3,2,1,8};
  CeedInt elemsPerBlock = 1;//basis->Q1d < 7 ? optElems[basis->Q1d] : 1;
  // ceil(nelem / elemsPerBlock)
  CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)?
                                     1 : 0 );

  const CeedScalar *d_u;
  CeedScalar *d_v;
  if(emode!=CEED_EVAL_WEIGHT) {
    ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChk(ierr);
  }
  ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChk(ierr);

  if (tmode == CEED_TRANSPOSE) {
    // Transpose mode accumulates element contributions, so start from zero.
    CeedInt length;
    ierr = CeedVectorGetLength(v, &length); CeedChk(ierr);
    ierr = cudaMemset(d_v, 0, length * sizeof(CeedScalar));
    CeedChk_Cu(ceed, ierr);
  }
  if (emode == CEED_EVAL_INTERP) {
    //TODO: check performance difference between c_B and d_B
    CeedInt P1d, Q1d;
    ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
    ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
    // Upload interp1d for the kernel (re-done on every call for now; see
    // the commented-out caching logic).
    // if (ceed_Cuda->Q1d != Q1d || ceed_Cuda->P1d != P1d)
    // {
    //   ceed_Cuda->Q1d = Q1d;
    //   ceed_Cuda->P1d = P1d;
    //   ceed_Cuda->grad = false;
      ierr = CeedCudaInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B);
      CeedChk(ierr);
    // }
    void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B, &d_u, &d_v};
    // void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->d_interp1d, &d_u, &d_v};
    // Dynamic shared memory: one Q1d^dim scalar slice per block.
    if (dim==1)
    {
      CeedInt sharedMem = Q1d*sizeof(CeedScalar);
      ierr = run_kernel_dim_shared(ceed, data->interp, grid, Q1d, 1, elemsPerBlock, sharedMem,
                            interpargs);
      CeedChk(ierr);
    } else if (dim==2) {
      CeedInt sharedMem = Q1d*Q1d*sizeof(CeedScalar);
      ierr = run_kernel_dim_shared(ceed, data->interp, grid, Q1d, Q1d, elemsPerBlock, sharedMem,
                            interpargs);
      CeedChk(ierr);
    } else if (dim==3) {
      CeedInt sharedMem = Q1d*Q1d*Q1d*sizeof(CeedScalar);
      ierr = run_kernel_dim_shared(ceed, data->interp, grid, Q1d, Q1d, elemsPerBlock, sharedMem,
                            interpargs);
      CeedChk(ierr);
    }
  } else if (emode == CEED_EVAL_GRAD) {
    CeedInt P1d, Q1d;
    ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
    ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
    // Upload interp1d and grad1d for the kernel (re-done on every call for
    // now; see the commented-out caching logic).
    // if (ceed_Cuda->Q1d != Q1d || ceed_Cuda->P1d != P1d || !data->grad)
    // {
    //   ceed_Cuda->Q1d = Q1d;
    //   ceed_Cuda->P1d = P1d;
    //   ceed_Cuda->grad = true;
      ierr = CeedCudaInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d,
                                    Q1d, &data->c_B, &data->c_G);
      CeedChk(ierr);
    // }
    void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B, &data->c_G, &d_u, &d_v};
    // void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->d_interp1d, &data->d_grad1d, &d_u, &d_v};
    if (dim==1)
    {
      CeedInt sharedMem = Q1d*sizeof(CeedScalar);
      ierr = run_kernel_dim_shared(ceed, data->grad, grid, Q1d, 1, elemsPerBlock, sharedMem,
                          gradargs);
      CeedChk(ierr);
    } else if (dim==2) {
      CeedInt sharedMem = Q1d*Q1d*sizeof(CeedScalar);
      ierr = run_kernel_dim_shared(ceed, data->grad, grid, Q1d, Q1d, elemsPerBlock, sharedMem,
                          gradargs);
      CeedChk(ierr);
    } else if (dim==3) {
      CeedInt sharedMem = Q1d*Q1d*Q1d*sizeof(CeedScalar);
      ierr = run_kernel_dim_shared(ceed, data->grad, grid, Q1d, Q1d, elemsPerBlock, sharedMem,
                          gradargs);
      CeedChk(ierr);
    }
  } else if (emode == CEED_EVAL_WEIGHT) {
    CeedInt Q1d;
    ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
    void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
    if(dim==1){
      // Pack several elements per block, but guard against Q1d > 32 which
      // would otherwise give a zero-sized block (and divide by zero below).
      const CeedInt optElems = 32/Q1d;
      const CeedInt elemsPerBlock = optElems>0?optElems:1;
      const CeedInt gridsize = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
      ierr = run_kernel_dim(ceed, data->weight, gridsize, Q1d, elemsPerBlock, 1, weightargs);
      CeedChk(ierr);
    } else if(dim==2) {
      const CeedInt optElems = 32/(Q1d*Q1d);
      const CeedInt elemsPerBlock = optElems>0?optElems:1;
      const CeedInt gridsize = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
      ierr = run_kernel_dim(ceed, data->weight, gridsize, Q1d, Q1d, elemsPerBlock, weightargs);
      CeedChk(ierr);
    } else if(dim==3) {
      const CeedInt gridsize = nelem;
      ierr = run_kernel_dim(ceed, data->weight, gridsize, Q1d, Q1d, Q1d, weightargs);
      CeedChk(ierr);
    }
  }

  if(emode!=CEED_EVAL_WEIGHT) {
    ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChk(ierr);
  }
  ierr = CeedVectorRestoreArray(v, &d_v); CeedChk(ierr);

  return 0;
}
717 
// Release all device resources owned by a cuda-shared basis: the JIT
// module and the qweight/interp/grad device arrays, then the backend data
// struct itself.  Returns 0 on success, a CEED error code otherwise.
static int CeedBasisDestroy_Cuda_shared(CeedBasis basis) {
  int ierr;
  Ceed ceed;
  ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);

  CeedBasis_Cuda_shared *data;
  ierr = CeedBasisGetData(basis, (void *) &data); CeedChk(ierr);

  CeedChk_Cu(ceed, cuModuleUnload(data->module));

  // cudaFree returns a CUDA status; translate it with CeedChk_Cu.
  ierr = cudaFree(data->d_qweight1d); CeedChk_Cu(ceed, ierr);
  ierr = cudaFree(data->d_interp1d); CeedChk_Cu(ceed, ierr);
  ierr = cudaFree(data->d_grad1d); CeedChk_Cu(ceed, ierr);

  ierr = CeedFree(&data); CeedChk(ierr);

  return 0;
}
736 
// Create a tensor-product H1 basis for the cuda-shared backend: copy the
// 1D quadrature weights and the interp/grad matrices to the device, JIT
// compile the shared-memory kernel string with the basis sizes baked in
// as macros, and register the Apply/Destroy backend functions.
// qref1d is not needed on the device and is not copied.
int CeedBasisCreateTensorH1_Cuda_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
                                        const CeedScalar *interp1d,
                                        const CeedScalar *grad1d,
                                        const CeedScalar *qref1d,
                                        const CeedScalar *qweight1d,
                                        CeedBasis basis) {
  int ierr;
  Ceed ceed;
  ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
  CeedBasis_Cuda_shared *data;
  ierr = CeedCalloc(1, &data); CeedChk(ierr);

  // Device copy of the Q1d quadrature weights.
  const CeedInt qBytes = Q1d * sizeof(CeedScalar);
  ierr = cudaMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Cu(ceed, ierr);
  ierr = cudaMemcpy(data->d_qweight1d, qweight1d, qBytes,
                    cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);

  // Device copies of the Q1d x P1d interpolation and gradient matrices.
  const CeedInt iBytes = qBytes * P1d;
  ierr = cudaMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Cu(ceed, ierr);
  ierr = cudaMemcpy(data->d_interp1d, interp1d, iBytes,
                    cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);

  ierr = cudaMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Cu(ceed, ierr);
  ierr = cudaMemcpy(data->d_grad1d, grad1d, iBytes,
                    cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);

  // JIT-compile the kernel string with the basis sizes as macros, then
  // look up the three kernel entry points.
  CeedInt ncomp;
  ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
  ierr = compile(ceed, kernelsShared, &data->module, 7,
                 "Q1D", Q1d,
                 "P1D", P1d,
                 "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
                     Q1d : P1d, dim),
                 "BASIS_DIM", dim,
                 "BASIS_NCOMP", ncomp,
                 "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
                 "BASIS_NQPT", CeedIntPow(Q1d, dim)
                ); CeedChk(ierr);
  ierr = get_kernel(ceed, data->module, "interp", &data->interp);
  CeedChk(ierr);
  ierr = get_kernel(ceed, data->module, "grad", &data->grad);
  CeedChk(ierr);
  ierr = get_kernel(ceed, data->module, "weight", &data->weight);
  CeedChk(ierr);

  ierr = CeedBasisSetData(basis, (void *)&data);
  CeedChk(ierr);
  ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
                                CeedBasisApplyTensor_Cuda_shared);
  CeedChk(ierr);
  ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
                                CeedBasisDestroy_Cuda_shared);
  CeedChk(ierr);
  return 0;
}
792 
// Non-tensor H1 bases are not supported by the cuda-shared backend; this
// stub always reports an error through the CEED error handler.
int CeedBasisCreateH1_Cuda_shared(CeedElemTopology topo, CeedInt dim,
                                  CeedInt ndof, CeedInt nqpts,
                                  const CeedScalar *interp,
                                  const CeedScalar *grad,
                                  const CeedScalar *qref,
                                  const CeedScalar *qweight,
                                  CeedBasis basis) {
  Ceed ceed;
  int ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
  return CeedError(ceed, 1, "Backend does not implement generic H1 basis");
}
805