xref: /libCEED/backends/cuda-shared/ceed-cuda-shared-basis.c (revision 4247ecf36a2057295017a8f7d130aa86833ce9b5)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <ceed-backend.h>
18 #include <ceed.h>
19 #include "ceed-cuda-shared.h"
20 #include "../cuda/ceed-cuda.h"
21 
22 //*********************
23 // shared mem kernels
24 static const char *kernelsShared = QUOTE(
25 
// Entry-wise accumulation of one Q1D-long register array into another.
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  for (int q = 0; q < Q1D; ++q)
    r_V[q] += r_U[q];
}
30 
31 //////////
32 //  1D  //
33 //////////
34 
// Load the 1D dofs of (elem, comp) into the shared slice, zero-padding
// entries [P1D, Q1D). Every thread writes the whole slice; all threads
// store identical values, so the redundant writes are benign.
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U, CeedScalar *slice) {
  for (int i = 0; i < Q1D; i++)
    slice[i] = (i < P1D) ? d_U[i + comp*P1D + elem*BASIS_NCOMP*P1D] : 0.0;
}
43 
// Store one dof value for (elem, comp); threads with tidx >= P1D are
// padding lanes and write nothing.
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx >= P1D) return;
  d_V[tidx + comp*P1D + elem*BASIS_NCOMP*P1D] = r_V;
}
51 
// Load all Q1D quad values of (elem, comp, dim) into the shared slice.
// Layout: quad point fastest, then element, component, derivative dim.
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem, const CeedScalar *d_U, CeedScalar *slice) {
  const int base = elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  for (int q = 0; q < Q1D; q++)
    slice[q] = d_U[q + base];
}
58 
// Store one quad value; same layout as readQuads1d, one thread per point.
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem, const CeedScalar &r_V, CeedScalar *d_V) {
  const int ind = tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  d_V[ind] = r_V;
}
64 
// 1D x-contraction: V = sum_i B(i, tidx) * slice[i]. U and tidy are unused
// here; they are kept so all contractions share one calling convention.
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy,
                                   const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  CeedScalar sum = 0.0;
  for (int i = 0; i < P1D; ++i)
    sum += B[i + tidx*P1D] * slice[i]; // contract x direction
  V = sum;
}
73 
// Transpose 1D x-contraction: V = sum_i B(tidx, i) * slice[i].
// NOTE(review): lanes with tidx >= P1D still index B here (their result is
// discarded by writeDofs1d) — confirm c_B sizing makes this read safe.
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
    const int tidy,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  CeedScalar sum = 0.0;
  for (int i = 0; i < Q1D; ++i)
    sum += B[tidx + i*P1D] * slice[i]; // contract x direction
  V = sum;
}
82 
// Apply the 1D interpolation basis over all elements and components.
// Forward (!transpose): dofs -> quad values via B; transpose: quads -> dofs
// via B^T. threadIdx.x indexes the quad point; threadIdx.z selects the
// element within the block, grid-striding over the remaining elements.
// NOTE(review): r_t is passed to the contractions but never read by them.
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;

  // Grid-stride loop over elements; blockDim.z elements handled per block
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp=0; comp<BASIS_NCOMP; comp++) {
      if(!transpose) {
        // dofs -> quads: stage dofs in the shared slice, contract with B
        readDofs1d(elem, tidx, tidy, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, r_t, c_B, r_V);
        writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      } else {
        // quads -> dofs: contract with B^T
        readQuads1d(elem, tidx, tidy, comp, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, r_t, c_B, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
109 
// Apply the 1D gradient basis: in 1D only the derivative matrix G is
// needed, so c_B is unused here (kept for a uniform signature).
// Forward: dofs -> d/dx at quads via G; transpose: quads -> dofs via G^T.
// NOTE(review): r_U is passed to the contractions but never read by them.
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;//=>this is really a nb of elements per block
  int dim;

  // Grid-stride loop over elements; blockDim.z elements handled per block
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp=0; comp<BASIS_NCOMP; comp++) {
      if(!transpose) {
        readDofs1d(elem, tidx, tidy, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, r_U, c_G, r_V);
        dim = 0; // only one derivative direction in 1D
        writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      } else {
        dim = 0;
        readQuads1d(elem, tidx, tidy, comp, dim, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, r_U, c_G, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
138 //////////
139 //  2D  //
140 //////////
141 
// Load one dof of (elem, comp) into a register; threads outside the
// P1D x P1D dof grid receive a zero pad value.
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U, CeedScalar &U) {
  if (tidx < P1D && tidy < P1D)
    U = d_U[tidx + tidy*P1D + comp*P1D*P1D + elem*BASIS_NCOMP*P1D*P1D ];
  else
    U = 0.0;
}
149 
// Store one dof of (elem, comp); padding threads outside P1D x P1D skip.
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D) return;
  d_V[tidx + tidy*P1D + comp*P1D*P1D + elem*BASIS_NCOMP*P1D*P1D ] = r_V;
}
157 
// Load one quad value of (elem, comp, dim); layout: point fastest, then
// element, component, derivative dim.
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem, const CeedScalar *d_U, CeedScalar &U ) {
  const int ind = tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                  dim*BASIS_NCOMP*nelem*Q1D*Q1D;
  U = d_U[ind];
}
164 
// Store one quad value; same layout as readQuads2d.
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem, const CeedScalar &r_V, CeedScalar *d_V) {
  const int ind = tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                  dim*BASIS_NCOMP*nelem*Q1D*Q1D;
  d_V[ind] = r_V;
}
171 
// 2D x-contraction: V(tidx,tidy) = sum_i B(i,tidx) * U(i,tidy) within each
// tidz layer. U is staged through shared memory so threads can read each
// other's values; both barriers are outside any divergent branch.
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  for (int i = 0; i < P1D; ++i) {
    V += B[i + tidx*P1D] * slice[i + tidy*Q1D + tidz*Q1D*Q1D];//contract x direction
  }
  // Barrier before the caller overwrites the slice with new data
  __syncthreads();
}
183 
// 2D y-contraction: V(tidx,tidy) = sum_i B(i,tidy) * U(tidx,i) within each
// tidz layer, staged through shared memory like ContractX2d.
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  for (int i = 0; i < P1D; ++i) {
    V += B[i + tidy*P1D] * slice[tidx + i*Q1D + tidz*Q1D*Q1D];//contract y direction
  }
  // Barrier before the caller overwrites the slice with new data
  __syncthreads();
}
195 
// Transpose 2D y-contraction: V(tidx,tidy) = sum_i B(tidy,i) * U(tidx,i).
// Only lanes with tidy < P1D produce output; the trailing barrier stays
// outside the divergent branch so every thread reaches it.
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidy<P1D) {
    for (int i = 0; i < Q1D; ++i) {
      V += B[tidy + i*P1D] * slice[tidx + i*Q1D + tidz*Q1D*Q1D];//contract y direction
    }
  }
  __syncthreads();
}
209 
// Transpose 2D x-contraction: V(tidx,tidy) = sum_i B(tidx,i) * U(i,tidy).
// Only lanes with tidx < P1D produce output; the trailing barrier stays
// outside the divergent branch so every thread reaches it.
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidx<P1D) {
    for (int i = 0; i < Q1D; ++i) {
      V += B[tidx + i*P1D] * slice[i + tidy*Q1D + tidz*Q1D*Q1D];//contract x direction
    }
  }
  __syncthreads();
}
223 
// Apply the 2D interpolation basis. Forward: dofs -> quads via B x B;
// transpose: quads -> dofs via B^T x B^T. Each z-layer of the block handles
// one (element, component) pair: blockDim.z == BASIS_NCOMP * elemsPerBlock
// (set by the host launcher).
// Fix: the loop body previously redeclared `const int comp`, shadowing the
// identical outer declaration; the redundant inner copy is removed.
// NOTE(review): if the element count is not uniform across block layers,
// threads leaving this loop early stop participating in the __syncthreads
// inside the contractions — confirm the launch configuration avoids this.
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    r_V = 0.0;
    r_t = 0.0;
    if(!transpose) {
      // dofs -> quads: contract x then y with B
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
    } else {
      // quads -> dofs: contract y then x with B^T
      readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
258 
// Apply the 2D gradient basis. Forward: dofs -> (d/dx, d/dy) at quads,
// writing one output per derivative dim; transpose: sums the adjoint
// contributions of both dims into the dofs. Each z-layer of the block
// handles one (element, component) pair.
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // for(int comp=0; comp<BASIS_NCOMP; comp++) {
      if(!transpose) {
        // dim 0: G in x, B in y
        readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
        ContractX2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
        ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        dim = 0;
        writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
        // dim 1: B in x, G in y
        ContractX2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
        ContractY2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
        dim = 1;
        writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      } else {
        // adjoint of dim 0
        dim = 0;
        readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
        ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
        ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
        // adjoint of dim 1, accumulated into r_V
        dim = 1;
        readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
        ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
        ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
        r_V+=r_U;
        writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    // }
  }
}
303 //////////
304 //  3D  //
305 //////////
306 
// Load a z-column of P1D dofs into registers, zero-padding both the
// threads outside the P1D x P1D grid and entries [P1D, Q1D).
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U, CeedScalar *r_U) {
  const bool active = (tidx < P1D) && (tidy < P1D);
  for (int i = 0; i < Q1D; i++)
    r_U[i] = (active && i < P1D) ? d_U[tidx + tidy*P1D + i*P1D*P1D + comp*P1D*P1D*P1D +
                                       elem*BASIS_NCOMP*P1D*P1D*P1D ] : 0.0;
}
317 
// Load a z-column of Q1D quad values of (elem, comp, dim) into registers.
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem, const CeedScalar *d_U, CeedScalar *r_U) {
  const int base = tidx + tidy*Q1D + elem*Q1D*Q1D*Q1D +
                   comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
  for (int i = 0; i < Q1D; i++)
    r_U[i] = d_U[base + i*Q1D*Q1D];
}
325 
// Store a z-column of P1D dofs; padding threads outside P1D x P1D skip.
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V, CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D) return;
  const int base = tidx + tidy*P1D + comp*P1D*P1D*P1D +
                   elem*BASIS_NCOMP*P1D*P1D*P1D;
  for (int i = 0; i < P1D; i++)
    d_V[base + i*P1D*P1D] = r_V[i];
}
335 
// Store a z-column of Q1D quad values; same layout as readQuads3d.
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem, const CeedScalar *r_V, CeedScalar *d_V) {
  const int base = tidx + tidy*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
                   dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
  for (int i = 0; i < Q1D; i++)
    d_V[base + i*Q1D*Q1D] = r_V[i];
}
343 
// 3D x-contraction, one z-plane (k) at a time: V[k](tidx,tidy) =
// sum_i B(i,tidx) * U[k](i,tidy). Each plane is staged through the shared
// slice; barriers are outside any divergent branch.
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy,
                                   const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx+tidy*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    for (int i = 0; i < P1D; ++i) {
      V[k] += B[i + tidx*P1D] * slice[i + tidy*Q1D];//contract x direction
    }
    // Barrier before the next plane overwrites the slice
    __syncthreads();
  }
}
357 
// 3D y-contraction, one z-plane (k) at a time: V[k](tidx,tidy) =
// sum_i B(i,tidy) * U[k](tidx,i), staged through the shared slice.
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy,
                                   const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx+tidy*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    for (int i = 0; i < P1D; ++i) {
      V[k] += B[i + tidy*P1D] * slice[tidx + i*Q1D];//contract y direction
    }
    // Barrier before the next plane overwrites the slice
    __syncthreads();
  }
}
371 
// 3D z-contraction: V[k] = sum_i B(i,k) * U[i]. The z-column already lives
// in this thread's registers, so no shared memory or synchronization is
// needed; slice, tidx, and tidy are unused (uniform interface).
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy,
                                   const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < Q1D; ++k) {
    CeedScalar sum = 0.0;
    for (int i = 0; i < P1D; ++i)
      sum += B[i + k*P1D] * U[i]; // contract z direction
    V[k] = sum;
  }
}
382 
// Transpose 3D z-contraction: V[k] = sum_i B(k,i) * U[i] for k < P1D and
// zero beyond. Thread-local like ContractZ3d; no synchronization needed.
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
    const int tidy,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < Q1D; ++k) {
    CeedScalar sum = 0.0;
    if (k < P1D) {
      for (int i = 0; i < Q1D; ++i)
        sum += B[k + i*P1D] * U[i]; // contract z direction
    }
    V[k] = sum;
  }
}
395 
// Transpose 3D y-contraction, plane by plane: V[k](tidx,tidy) =
// sum_i B(tidy,i) * U[k](tidx,i). Only lanes with tidy < P1D produce
// output; barriers stay outside the divergent branch.
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
    const int tidy,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx+tidy*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidy<P1D) {
      for (int i = 0; i < Q1D; ++i) {
        V[k] += B[tidy + i*P1D] * slice[tidx + i*Q1D];//contract y direction
      }
    }
    __syncthreads();
  }
}
411 
// Transpose 3D x-contraction, plane by plane: V[k](tidx,tidy) =
// sum_i B(tidx,i) * U[k](i,tidy). Only lanes with tidx < P1D produce
// output; barriers stay outside the divergent branch.
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
    const int tidy,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx+tidy*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx<P1D) {
      for (int i = 0; i < Q1D; ++i) {
        V[k] += B[tidx + i*P1D] * slice[i + tidy*Q1D];//contract x direction
      }
    }
    __syncthreads();
  }
}
427 
// Apply the 3D interpolation basis. Forward: dofs -> quads via B x B x B;
// transpose: quads -> dofs via the transposed contractions in reverse
// order. Each thread owns a z-column of registers; x/y contractions go
// through shared memory, the z contraction is thread-local.
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V[Q1D];
  CeedScalar r_t[Q1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;

  // Grid-stride loop over elements; blockDim.z elements handled per block
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp=0; comp<BASIS_NCOMP; comp++) {
      // Clear both register columns before reuse across components
      for (int i = 0; i < Q1D; ++i) {
        r_V[i] = 0.0;
        r_t[i] = 0.0;
      }
      if(!transpose) {
        // dofs -> quads: ping-pong between r_V and r_t through x, y, z
        readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
        ContractX3d(slice, tidx, tidy, r_V, c_B, r_t);
        ContractY3d(slice, tidx, tidy, r_t, c_B, r_V);
        ContractZ3d(slice, tidx, tidy, r_V, c_B, r_t);
        writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
      } else {
        // quads -> dofs: transposed contractions, z then y then x
        readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
        ContractTransposeZ3d(slice, tidx, tidy, r_V, c_B, r_t);
        ContractTransposeY3d(slice, tidx, tidy, r_t, c_B, r_V);
        ContractTransposeX3d(slice, tidx, tidy, r_V, c_B, r_t);
        writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
      }
    }
  }
}
461 
// Apply the 3D gradient basis. Forward: for each derivative dim, apply G
// in that direction and B in the other two, writing one quad output per
// dim. Transpose: accumulate the adjoint contributions of all three dims
// into the dofs via add().
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  //use P1D for one of these
  CeedScalar r_U[Q1D];
  CeedScalar r_V[Q1D];
  CeedScalar r_t[Q1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  int dim;

  // Grid-stride loop over elements; blockDim.z elements handled per block
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp=0; comp<BASIS_NCOMP; comp++) {
      if(!transpose) {
        // dim 0: G in x, B in y and z
        readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
        ContractX3d(slice, tidx, tidy, r_U, c_G, r_V);
        ContractY3d(slice, tidx, tidy, r_V, c_B, r_t);
        ContractZ3d(slice, tidx, tidy, r_t, c_B, r_V);
        dim = 0;
        writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
        // dim 1: G in y
        ContractX3d(slice, tidx, tidy, r_U, c_B, r_V);
        ContractY3d(slice, tidx, tidy, r_V, c_G, r_t);
        ContractZ3d(slice, tidx, tidy, r_t, c_B, r_V);
        dim = 1;
        writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
        // dim 2: G in z
        ContractX3d(slice, tidx, tidy, r_U, c_B, r_V);
        ContractY3d(slice, tidx, tidy, r_V, c_B, r_t);
        ContractZ3d(slice, tidx, tidy, r_t, c_G, r_V);
        dim = 2;
        writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      } else {
        // adjoint of dim 0 into r_V
        dim = 0;
        readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
        ContractTransposeZ3d(slice, tidx, tidy, r_U, c_B, r_t);
        ContractTransposeY3d(slice, tidx, tidy, r_t, c_B, r_U);
        ContractTransposeX3d(slice, tidx, tidy, r_U, c_G, r_V);
        // adjoint of dim 1, accumulated via add()
        dim = 1;
        readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
        ContractTransposeZ3d(slice, tidx, tidy, r_U, c_B, r_t);
        ContractTransposeY3d(slice, tidx, tidy, r_t, c_G, r_U);
        ContractTransposeX3d(slice, tidx, tidy, r_U, c_B, r_t);
        add(r_V, r_t);
        // adjoint of dim 2, accumulated via add()
        dim = 2;
        readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
        ContractTransposeZ3d(slice, tidx, tidy, r_U, c_G, r_t);
        ContractTransposeY3d(slice, tidx, tidy, r_t, c_B, r_U);
        ContractTransposeX3d(slice, tidx, tidy, r_U, c_B, r_t);
        add(r_V, r_t);
        writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
518 
519 /////////////
520 // Kernels //
521 /////////////
// Entry point: dispatch interpolation to the dimension-specific
// implementation. BASIS_DIM is a compile-time macro, so the dead branches
// are eliminated. Dynamic shared memory is sized by the host launcher.
extern "C" __global__ void interp(const CeedInt nelem, const int transpose,
                                  const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  extern __shared__ double slice[];
  switch (BASIS_DIM) {
  case 1: interp1d(nelem, transpose, c_B, d_U, d_V, slice); break;
  case 2: interp2d(nelem, transpose, c_B, d_U, d_V, slice); break;
  case 3: interp3d(nelem, transpose, c_B, d_U, d_V, slice); break;
  }
}
535 
// Entry point: dispatch the gradient to the dimension-specific
// implementation; see interp for the dispatch/shared-memory conventions.
extern "C" __global__ void grad(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *c_G,
                                const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V) {
  extern __shared__ double slice[];
  switch (BASIS_DIM) {
  case 1: grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice); break;
  case 2: grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice); break;
  case 3: grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice); break;
  }
}
549 
550 /////////////
551 // Weights //
552 /////////////
// Broadcast the 1D quadrature weights: threadIdx.x is the quad point,
// threadIdx.y the element within the block; grid-stride over elements.
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int q = threadIdx.x;
  const CeedScalar wq = qweight1d[q];
  for (CeedInt elem = blockIdx.x*blockDim.y + threadIdx.y; elem < nelem;
       elem += gridDim.x*blockDim.y) {
    w[elem*Q1D + q] = wq;
  }
}
563 
// Broadcast the 2D quadrature weights (tensor product of the 1D weights);
// (threadIdx.x, threadIdx.y) is the quad point, threadIdx.z the element.
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy];
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    w[elem*Q1D*Q1D + qx + qy*Q1D] = wq;
  }
}
575 
// Broadcast the 3D quadrature weights: one block per element (grid-stride),
// one thread per 3D quad point.
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const int qz = threadIdx.z;
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy]*qweight1d[qz];
  for (int e = blockIdx.x; e < nelem; e += gridDim.x) {
    w[e*Q1D*Q1D*Q1D + qx + qy*Q1D + qz*Q1D*Q1D] = wq;
  }
}
587 
// Entry point: dispatch quadrature-weight generation on BASIS_DIM.
extern "C" __global__ void weight(const CeedInt nelem,
                                  const CeedScalar *__restrict__ qweight1d, CeedScalar *__restrict__ v) {
  switch (BASIS_DIM) {
  case 1: weight1d(nelem, qweight1d, v); break;
  case 2: weight2d(nelem, qweight1d, v); break;
  case 3: weight3d(nelem, qweight1d, v); break;
  }
}
598 
599                                    );
600 
601 int CeedCudaInitInterp(CeedScalar *d_B, CeedInt P1d, CeedInt Q1d,
602                        CeedScalar **c_B);
603 int CeedCudaInitInterpGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P1d,
604                            CeedInt Q1d, CeedScalar **c_B_ptr, CeedScalar **c_G_ptr);
605 
// Ceil division: number of thread blocks needed to cover nelem elements at
// elemsPerBlock elements per block.
static CeedInt CeedCudaSharedGridSize(CeedInt nelem, CeedInt elemsPerBlock) {
  return (nelem + elemsPerBlock - 1) / elemsPerBlock;
}

// Apply a tensor basis operation (interp, grad, or weight) on the device.
// Fixes vs. previous revision:
//  - CeedGetData/CeedBasisGetData return codes were dropped, so CeedChk
//    tested a stale ierr; they are now assigned and checked.
//  - cudaMemset returns a cudaError_t and is now checked with CeedChk_Cu.
//  - elemsPerBlock is clamped to >= 1 (previously 0 for ncomp >
//    optElems[Q1d] in 2D, or Q1d > 32 for 1D weights, dividing by zero).
//  - The weight-path kernel launches now have their return codes checked.
int CeedBasisApplyTensor_Cuda_shared(CeedBasis basis, const CeedInt nelem,
                                     CeedTransposeMode tmode,
                                     CeedEvalMode emode, CeedVector u, CeedVector v) {
  int ierr;
  Ceed ceed;
  ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
  Ceed_Cuda_shared *ceed_Cuda;
  ierr = CeedGetData(ceed, (void *) &ceed_Cuda); CeedChk(ierr);
  CeedBasis_Cuda_shared *data;
  ierr = CeedBasisGetData(basis, (void *)&data); CeedChk(ierr);
  const CeedInt transpose = tmode == CEED_TRANSPOSE;
  CeedInt dim, ncomp;
  ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
  ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);

  // Device access to the vectors (weight mode has no input vector)
  const CeedScalar *d_u;
  CeedScalar *d_v;
  if (emode != CEED_EVAL_WEIGHT) {
    ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChk(ierr);
  }
  ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChk(ierr);

  // Transpose kernels accumulate into v, so clear it first
  if (tmode == CEED_TRANSPOSE) {
    CeedInt length;
    ierr = CeedVectorGetLength(v, &length); CeedChk(ierr);
    ierr = cudaMemset(d_v, 0, length * sizeof(CeedScalar));
    CeedChk_Cu(ceed, ierr);
  }

  if (emode == CEED_EVAL_INTERP) {
    //TODO: check performance difference between c_B and d_B
    CeedInt P1d, Q1d;
    ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
    ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
    ierr = CeedCudaInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B);
    CeedChk(ierr);
    void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B, &d_u, &d_v};
    if (dim == 1) {
      // Block: Q1d threads in x, one element per block
      const CeedInt elemsPerBlock = 1;
      const CeedInt grid = CeedCudaSharedGridSize(nelem, elemsPerBlock);
      const CeedInt sharedMem = Q1d*sizeof(CeedScalar);
      ierr = run_kernel_dim_shared(ceed, data->interp, grid, Q1d, 1,
                                   elemsPerBlock, sharedMem, interpargs);
      CeedChk(ierr);
    } else if (dim == 2) {
      // Empirically tuned elements per block for small Q1d; must stay >= 1
      const CeedInt optElems[7] = {0,32,8,6,4,2,8};
      CeedInt elemsPerBlock = Q1d < 7 ? optElems[Q1d]/ncomp : 1;
      if (elemsPerBlock < 1) elemsPerBlock = 1;
      const CeedInt grid = CeedCudaSharedGridSize(nelem, elemsPerBlock);
      const CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
      ierr = run_kernel_dim_shared(ceed, data->interp, grid, Q1d, Q1d,
                                   ncomp*elemsPerBlock, sharedMem, interpargs);
      CeedChk(ierr);
    } else if (dim == 3) {
      const CeedInt elemsPerBlock = 1;
      const CeedInt grid = CeedCudaSharedGridSize(nelem, elemsPerBlock);
      const CeedInt sharedMem = Q1d*Q1d*Q1d*sizeof(CeedScalar);
      ierr = run_kernel_dim_shared(ceed, data->interp, grid, Q1d, Q1d,
                                   elemsPerBlock, sharedMem, interpargs);
      CeedChk(ierr);
    }
  } else if (emode == CEED_EVAL_GRAD) {
    CeedInt P1d, Q1d;
    ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
    ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
    ierr = CeedCudaInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d,
                                  Q1d, &data->c_B, &data->c_G);
    CeedChk(ierr);
    void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B, &data->c_G, &d_u, &d_v};
    if (dim == 1) {
      const CeedInt elemsPerBlock = 1;
      const CeedInt grid = CeedCudaSharedGridSize(nelem, elemsPerBlock);
      const CeedInt sharedMem = Q1d*sizeof(CeedScalar);
      ierr = run_kernel_dim_shared(ceed, data->grad, grid, Q1d, 1,
                                   elemsPerBlock, sharedMem, gradargs);
      CeedChk(ierr);
    } else if (dim == 2) {
      // Same launch geometry as the 2D interp case
      const CeedInt optElems[7] = {0,32,8,6,4,2,8};
      CeedInt elemsPerBlock = Q1d < 7 ? optElems[Q1d]/ncomp : 1;
      if (elemsPerBlock < 1) elemsPerBlock = 1;
      const CeedInt grid = CeedCudaSharedGridSize(nelem, elemsPerBlock);
      const CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
      ierr = run_kernel_dim_shared(ceed, data->grad, grid, Q1d, Q1d,
                                   ncomp*elemsPerBlock, sharedMem, gradargs);
      CeedChk(ierr);
    } else if (dim == 3) {
      const CeedInt elemsPerBlock = 1;
      const CeedInt grid = CeedCudaSharedGridSize(nelem, elemsPerBlock);
      const CeedInt sharedMem = Q1d*Q1d*Q1d*sizeof(CeedScalar);
      ierr = run_kernel_dim_shared(ceed, data->grad, grid, Q1d, Q1d,
                                   elemsPerBlock, sharedMem, gradargs);
      CeedChk(ierr);
    }
  } else if (emode == CEED_EVAL_WEIGHT) {
    CeedInt Q1d;
    ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
    void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
    if (dim == 1) {
      const CeedInt elemsPerBlock = Q1d > 32 ? 1 : 32/Q1d;
      const CeedInt gridsize = CeedCudaSharedGridSize(nelem, elemsPerBlock);
      ierr = run_kernel_dim(ceed, data->weight, gridsize, Q1d, elemsPerBlock, 1,
                            weightargs);
      CeedChk(ierr);
    } else if (dim == 2) {
      const CeedInt optElems = 32/(Q1d*Q1d);
      const CeedInt elemsPerBlock = optElems > 0 ? optElems : 1;
      const CeedInt gridsize = CeedCudaSharedGridSize(nelem, elemsPerBlock);
      ierr = run_kernel_dim(ceed, data->weight, gridsize, Q1d, Q1d, elemsPerBlock,
                            weightargs);
      CeedChk(ierr);
    } else if (dim == 3) {
      // One element per block
      ierr = run_kernel_dim(ceed, data->weight, nelem, Q1d, Q1d, Q1d, weightargs);
      CeedChk(ierr);
    }
  }

  if (emode != CEED_EVAL_WEIGHT) {
    ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChk(ierr);
  }
  ierr = CeedVectorRestoreArray(v, &d_v); CeedChk(ierr);

  return 0;
}
737 
// Destroy the backend data for a shared-memory basis: unload the JIT
// module and free the device copies of the weights and 1D matrices.
static int CeedBasisDestroy_Cuda_shared(CeedBasis basis) {
  int ierr;
  Ceed ceed;
  ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);

  CeedBasis_Cuda_shared *data;
  ierr = CeedBasisGetData(basis, (void *) &data); CeedChk(ierr);

  CeedChk_Cu(ceed, cuModuleUnload(data->module));

  ierr = cudaFree(data->d_qweight1d); CeedChk_Cu(ceed, ierr);
  ierr = cudaFree(data->d_interp1d); CeedChk_Cu(ceed, ierr);
  ierr = cudaFree(data->d_grad1d); CeedChk_Cu(ceed, ierr);

  ierr = CeedFree(&data); CeedChk(ierr);

  return 0;
}
756 
// Create a tensor-product H1 basis for the cuda-shared backend: copy the
// quadrature weights and 1D interp/grad matrices to the device, JIT-compile
// the shared-memory kernel source with the basis sizes baked in as macros,
// and register the Apply/Destroy backend functions. qref1d is not used on
// the device.
int CeedBasisCreateTensorH1_Cuda_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
                                        const CeedScalar *interp1d,
                                        const CeedScalar *grad1d,
                                        const CeedScalar *qref1d,
                                        const CeedScalar *qweight1d,
                                        CeedBasis basis) {
  int ierr;
  Ceed ceed;
  ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
  CeedBasis_Cuda_shared *data;
  ierr = CeedCalloc(1, &data); CeedChk(ierr);

  // Device copy of the Q1d quadrature weights
  const CeedInt qBytes = Q1d * sizeof(CeedScalar);
  ierr = cudaMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Cu(ceed, ierr);
  ierr = cudaMemcpy(data->d_qweight1d, qweight1d, qBytes,
                    cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);

  // Device copies of the Q1d x P1d interpolation and gradient matrices
  const CeedInt iBytes = qBytes * P1d;
  ierr = cudaMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Cu(ceed, ierr);
  ierr = cudaMemcpy(data->d_interp1d, interp1d, iBytes,
                    cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);

  ierr = cudaMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Cu(ceed, ierr);
  ierr = cudaMemcpy(data->d_grad1d, grad1d, iBytes,
                    cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);

  // Compile the kernel string; sizes become compile-time constants so the
  // contractions can fully unroll and size shared/register arrays
  CeedInt ncomp;
  ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
  ierr = compile(ceed, kernelsShared, &data->module, 7,
                 "Q1D", Q1d,
                 "P1D", P1d,
                 "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
                     Q1d : P1d, dim),
                 "BASIS_DIM", dim,
                 "BASIS_NCOMP", ncomp,
                 "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
                 "BASIS_NQPT", CeedIntPow(Q1d, dim)
                ); CeedChk(ierr);
  ierr = get_kernel(ceed, data->module, "interp", &data->interp);
  CeedChk(ierr);
  ierr = get_kernel(ceed, data->module, "grad", &data->grad);
  CeedChk(ierr);
  ierr = get_kernel(ceed, data->module, "weight", &data->weight);
  CeedChk(ierr);

  ierr = CeedBasisSetData(basis, (void *)&data);
  CeedChk(ierr);
  ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
                                CeedBasisApplyTensor_Cuda_shared);
  CeedChk(ierr);
  ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
                                CeedBasisDestroy_Cuda_shared);
  CeedChk(ierr);
  return 0;
}
812 
// Non-tensor H1 bases are not supported by the cuda-shared backend; always
// returns an error.
int CeedBasisCreateH1_Cuda_shared(CeedElemTopology topo, CeedInt dim,
                                  CeedInt ndof, CeedInt nqpts,
                                  const CeedScalar *interp,
                                  const CeedScalar *grad,
                                  const CeedScalar *qref,
                                  const CeedScalar *qweight,
                                  CeedBasis basis) {
  int ierr;
  Ceed ceed;
  ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
  return CeedError(ceed, 1, "Backend does not implement generic H1 basis");
}
825