xref: /libCEED/backends/cuda-shared/ceed-cuda-shared-basis.c (revision 698ebc35146dd735d5af4885299c45c460a0930c)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <ceed-backend.h>
18 #include <ceed.h>
19 #include "ceed-cuda-shared.h"
20 #include "../cuda/ceed-cuda.h"
21 
22 //*********************
23 // shared mem kernels
24 static const char *kernelsShared = QUOTE(
25 
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  // Component-wise accumulation of Q1D register values: r_V += r_U.
  for (int q = 0; q < Q1D; ++q)
    r_V[q] += r_U[q];
}
30 
31 //////////
32 //  1D  //
33 //////////
34 
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int tidz,const int comp,
                                  const int nelem, const CeedScalar *d_U, CeedScalar *slice) {
  // Load this element/component's P1D nodal values into the shared-memory
  // row owned by thread layer tidz, zero-padding entries P1D..Q1D-1 so the
  // contractions can iterate uniformly over Q1D.
  CeedScalar *row = slice + tidz*Q1D;
  const CeedScalar *src = d_U + comp*P1D + elem*BASIS_NCOMP*P1D;
  for (int node = 0; node < P1D; node++)
    row[node] = src[node];
  for (int node = P1D; node < Q1D; node++)
    row[node] = 0.0;
}
43 
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V, CeedScalar *d_V) {
  // Only the first P1D threads along x hold valid nodal results.
  if (tidx >= P1D) return;
  const int ind = tidx + comp*P1D + elem*BASIS_NCOMP*P1D;
  d_V[ind] = r_V;
}
51 
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int tidz, const int comp,
                                   const int dim, const int nelem, const CeedScalar *d_U, CeedScalar *slice) {
  // Copy this element's Q1D quadrature values (one component, one
  // derivative dimension) into the shared-memory row of layer tidz.
  const CeedScalar *src = d_U + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  CeedScalar *row = slice + tidz*Q1D;
  for (int q = 0; q < Q1D; q++)
    row[q] = src[q];
}
58 
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem, const CeedScalar &r_V, CeedScalar *d_V) {
  // Quadrature layout: point index fastest, then element, component, dim.
  const int ind = tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  d_V[ind] = r_V;
}
64 
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  // 1D contraction along x: combine the P1D values staged in this thread
  // layer's shared-memory row with the basis matrix B.
  // NOTE(review): the U argument is unused; data comes from `slice`
  // (filled beforehand by readDofs1d).
  V = 0.0;
  for (int i = 0; i < P1D; ++i) {
    V += B[i + tidx*P1D] * slice[i+tidz*Q1D];//contract x direction
  }
}
73 
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  // Transposed 1D contraction along x: combine the Q1D quadrature values in
  // this layer's shared-memory row with B accessed in transposed order.
  // NOTE(review): the U argument is unused; data comes from `slice`
  // (filled beforehand by readQuads1d).
  V = 0.0;
  for (int i = 0; i < Q1D; ++i) {
    V += B[tidx + i*P1D] * slice[i+tidz*Q1D];//contract x direction
  }
}
82 
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  // 1D interpolation: apply c_B (or its transpose when `transpose` is set)
  // for every element/component.  Thread layout: x = 1D point index,
  // z = element slot within the block; `slice` holds one Q1D row per layer.
  CeedScalar r_V;
  CeedScalar r_t;  // NOTE(review): passed as U to the contractions, which ignore it

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;


  // Stride over elements: each z layer of each block takes one element per
  // iteration until all nelem are covered.
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp=0; comp<BASIS_NCOMP; comp++) {
      if(!transpose) {
        // Nodes -> quadrature points: stage dofs in slice, contract with B.
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      } else {
        // Quadrature points -> nodes: contract with B transposed.
        readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
110 
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // 1D gradient: same structure as interp1d but contracts with the
  // derivative matrix c_G.  NOTE(review): c_B is never referenced here.
  CeedScalar r_U;  // NOTE(review): passed as U to the contractions, which ignore it
  CeedScalar r_V;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  int dim;

  // Stride over elements, one per z layer per iteration.
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp=0; comp<BASIS_NCOMP; comp++) {
      if(!transpose) {
        // Nodes -> derivative at quadrature points (single dimension in 1D).
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        dim = 0;
        writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      } else {
        // Transposed action: quadrature derivative values back to nodes.
        dim = 0;
        readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
140 //////////
141 //  2D  //
142 //////////
143 
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U, CeedScalar &U) {
  // Each thread loads one nodal value; threads outside the P1D x P1D node
  // grid receive zero so contractions over the full Q1D range stay valid.
  if (tidx < P1D && tidy < P1D)
    U = d_U[tidx + tidy*P1D + comp*P1D*P1D + elem*BASIS_NCOMP*P1D*P1D ];
  else
    U = 0.0;
}
151 
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V, CeedScalar *d_V) {
  // Only threads inside the P1D x P1D node grid hold valid results.
  if (tidx >= P1D || tidy >= P1D) return;
  const int ind = tidx + tidy*P1D + comp*P1D*P1D + elem*BASIS_NCOMP*P1D*P1D;
  d_V[ind] = r_V;
}
159 
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem, const CeedScalar *d_U, CeedScalar &U ) {
  // One quadrature value per thread; layout: (x,y) point fastest, then
  // element, then component, then derivative dimension.
  const int ind = tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                  dim*BASIS_NCOMP*nelem*Q1D*Q1D;
  U = d_U[ind];
}
166 
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem, const CeedScalar &r_V, CeedScalar *d_V) {
  // Mirror of readQuads2d: same layout, one store per thread.
  const int ind = tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                  dim*BASIS_NCOMP*nelem*Q1D*Q1D;
  d_V[ind] = r_V;
}
173 
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  // Contract along x: stage each thread's value in shared memory, then
  // combine the P1D staged values with B.  The first barrier makes the full
  // slice visible before the reads; the second protects it from the next
  // caller's writes.  All threads in the block must reach both.
  slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  for (int i = 0; i < P1D; ++i) {
    V += B[i + tidx*P1D] * slice[i + tidy*Q1D + tidz*Q1D*Q1D];//contract x direction
  }
  __syncthreads();
}
185 
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  // Contract along y via shared-memory staging; barriers bracket the reads
  // and must be reached by every thread in the block.
  slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  for (int i = 0; i < P1D; ++i) {
    V += B[i + tidy*P1D] * slice[tidx + i*Q1D + tidz*Q1D*Q1D];//contract y direction
  }
  __syncthreads();
}
197 
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  // Transposed y contraction: only threads with tidy < P1D produce a value;
  // the rest keep V = 0.  The barriers sit outside the divergent branch so
  // every thread in the block reaches them.
  slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidy<P1D) {
    for (int i = 0; i < Q1D; ++i) {
      V += B[tidy + i*P1D] * slice[tidx + i*Q1D + tidz*Q1D*Q1D];//contract y direction
    }
  }
  __syncthreads();
}
211 
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  // Transposed x contraction: only threads with tidx < P1D produce a value;
  // the rest keep V = 0.  Barriers stay outside the divergent branch so all
  // threads reach them.
  slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidx<P1D) {
    for (int i = 0; i < Q1D; ++i) {
      V += B[tidx + i*P1D] * slice[i + tidy*Q1D + tidz*Q1D*Q1D];//contract x direction
    }
  }
  __syncthreads();
}
225 
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  // 2D interpolation: apply c_B in x then y (transposed order for the
  // transpose case).  One value per thread; x/y contractions stage data
  // through the shared-memory slice.
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  // z encodes both the element slot within the block and the component.
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // for(int comp=0; comp<BASIS_NCOMP; comp++) {
      const int comp = tidz%BASIS_NCOMP;  // NOTE(review): shadows the identical outer `comp`
      r_V = 0.0;
      r_t = 0.0;
      if(!transpose) {
        // Nodes -> quadrature points.
        readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
        ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
        ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      } else {
        // Quadrature points -> nodes (transposed contractions, y then x).
        readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
        ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
        ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    // }
  }
}
260 
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // 2D gradient: for each derivative dimension, apply c_G along that axis
  // and c_B along the other.  The transpose case sums the two transposed
  // contributions before writing nodal values.
  CeedScalar r_U;
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  // z encodes both the element slot within the block and the component.
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // for(int comp=0; comp<BASIS_NCOMP; comp++) {
      if(!transpose) {
        readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
        // Derivative dimension 0: G in x, B in y.
        ContractX2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
        ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        dim = 0;
        writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
        // Derivative dimension 1: B in x, G in y.
        ContractX2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
        ContractY2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
        dim = 1;
        writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      } else {
        // Transpose: accumulate both dimensions' contributions into r_V.
        dim = 0;
        readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
        ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
        ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
        dim = 1;
        readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
        ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
        ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
        r_V+=r_U;
        writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    // }
  }
}
305 //////////
306 //  3D  //
307 //////////
308 
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U, CeedScalar *r_U) {
  // Load a z-column of nodal values into registers.  Threads outside the
  // P1D x P1D node grid, and the padding entries P1D..Q1D-1, get zeros.
  const bool active = tidx < P1D && tidy < P1D;
  for (int k = 0; k < P1D; k++)
    r_U[k] = active ? d_U[tidx + tidy*P1D + k*P1D*P1D + comp*P1D*P1D*P1D +
                          elem*BASIS_NCOMP*P1D*P1D*P1D ] : 0.0;
  for (int k = P1D; k < Q1D; k++)
    r_U[k] = 0.0;
}
319 
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem, const CeedScalar *d_U, CeedScalar *r_U) {
  // Load a z-column of quadrature values into registers; consecutive z
  // levels are Q1D*Q1D apart in memory.
  const CeedScalar *src = d_U + tidx + tidy*Q1D + elem*Q1D*Q1D +
                          comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
  for (int k = 0; k < Q1D; k++)
    r_U[k] = src[k*Q1D*Q1D];
}
327 
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V, CeedScalar *d_V) {
  // Store a z-column of nodal results; only threads inside the P1D x P1D
  // node grid hold valid values.
  if (tidx >= P1D || tidy >= P1D) return;
  CeedScalar *dst = d_V + tidx + tidy*P1D + comp*P1D*P1D*P1D +
                    elem*BASIS_NCOMP*P1D*P1D*P1D;
  for (int k = 0; k < P1D; k++)
    dst[k*P1D*P1D] = r_V[k];
}
337 
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem, const CeedScalar *r_V, CeedScalar *d_V) {
  // Store a z-column of quadrature results; mirror of readQuads3d.
  CeedScalar *dst = d_V + tidx + tidy*Q1D + elem*Q1D*Q1D +
                    comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
  for (int k = 0; k < Q1D; k++)
    dst[k*Q1D*Q1D] = r_V[k];
}
345 
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  // Contract along x for each of the P1D z-levels held in registers; each
  // level is staged through shared memory one at a time.  Both barriers
  // must be reached by all threads in the block.
  for (int k = 0; k < P1D; ++k) {
    slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    for (int i = 0; i < P1D; ++i) {
      V[k] += B[i + tidx*P1D] * slice[i + tidy*Q1D + tidz*Q1D*Q1D];//contract x direction
    }
    __syncthreads();
  }
}
359 
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  // Contract along y for each of the P1D z-levels, staging each level
  // through shared memory; all threads must reach both barriers.
  for (int k = 0; k < P1D; ++k) {
    slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    for (int i = 0; i < P1D; ++i) {
      V[k] += B[i + tidy*P1D] * slice[tidx + i*Q1D + tidz*Q1D*Q1D];//contract y direction
    }
    __syncthreads();
  }
}
373 
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  // Contract along z entirely in registers; no shared memory (and hence no
  // barrier) is needed.  `slice` is unused, kept for signature symmetry.
  for (int k = 0; k < Q1D; ++k) {
    CeedScalar acc = 0.0;
    for (int i = 0; i < P1D; ++i)
      acc += B[i + k*P1D] * U[i];//contract z direction
    V[k] = acc;
  }
}
384 
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  // Transposed z contraction in registers; entries k = P1D..Q1D-1 are
  // zeroed.  `slice` is unused, kept for signature symmetry.
  for (int k = 0; k < Q1D; ++k) {
    CeedScalar acc = 0.0;
    if (k < P1D) {
      for (int i = 0; i < Q1D; ++i)
        acc += B[k + i*P1D] * U[i];//contract z direction
    }
    V[k] = acc;
  }
}
397 
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  // Transposed y contraction per z-level; only threads with tidy < P1D
  // produce a value.  Barriers stay outside the divergent branch so every
  // thread in the block reaches them.
  for (int k = 0; k < P1D; ++k) {
    slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidy<P1D) {
      for (int i = 0; i < Q1D; ++i) {
        V[k] += B[tidy + i*P1D] * slice[tidx + i*Q1D + tidz*Q1D*Q1D];//contract y direction
      }
    }
    __syncthreads();
  }
}
413 
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  // Transposed x contraction per z-level; only threads with tidx < P1D
  // produce a value.  Barriers stay outside the divergent branch so every
  // thread in the block reaches them.
  for (int k = 0; k < P1D; ++k) {
    slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx<P1D) {
      for (int i = 0; i < Q1D; ++i) {
        V[k] += B[tidx + i*P1D] * slice[i + tidy*Q1D + tidz*Q1D*Q1D];//contract x direction
      }
    }
    __syncthreads();
  }
}
429 
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  // 3D interpolation: x and y contractions go through shared memory (one
  // z-level at a time); the z contraction stays in registers.
  CeedScalar r_V[Q1D];
  CeedScalar r_t[Q1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  // z encodes both the element slot within the block and the component.
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // for(int comp=0; comp<BASIS_NCOMP; comp++) {
      // Clear the register columns before reusing them for this element.
      for (int i = 0; i < Q1D; ++i) {
        r_V[i] = 0.0;
        r_t[i] = 0.0;
      }
      if(!transpose) {
        // Nodes -> quadrature points: B in x, y, then z.
        readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
        ContractX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
        ContractY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        ContractZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
        writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
      } else {
        // Quadrature points -> nodes: transposed contractions z, y, then x.
        readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
        ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
        ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
        writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
      }
    // }
  }
}
467 
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // 3D gradient: for each derivative dimension, apply c_G along that axis
  // and c_B along the other two.  The transpose case sums the three
  // transposed contributions (via add) before writing nodal values.
  //use P1D for one of these
  CeedScalar r_U[Q1D];
  CeedScalar r_V[Q1D];
  CeedScalar r_t[Q1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  // z encodes both the element slot within the block and the component.
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // for(int comp=0; comp<BASIS_NCOMP; comp++) {
      if(!transpose) {
        readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
        // Derivative dimension 0: G in x, B in y/z.
        ContractX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
        ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        dim = 0;
        writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
        // Derivative dimension 1: G in y, B in x/z.
        ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
        ContractY3d(slice, tidx, tidy, tidz, r_V, c_G, r_t);
        ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        dim = 1;
        writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
        // Derivative dimension 2: G in z, B in x/y.
        ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
        ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
        ContractZ3d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
        dim = 2;
        writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      } else {
        // Transposed dimension-0 contribution.
        dim = 0;
        readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
        ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
        ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
        ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        // Transposed dimension-1 contribution.
        dim = 1;
        readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
        ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
        ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_G, r_U);
        ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
        add(r_V, r_t);
        // Transposed dimension-2 contribution.
        dim = 2;
        readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
        ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
        ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
        ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
        add(r_V, r_t);
        writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    // }
  }
}
528 
529 /////////////
530 // Kernels //
531 /////////////
extern "C" __global__ void interp(const CeedInt nelem, const int transpose,
                                  const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  // Dynamic shared memory; size is supplied by the host launcher.
  extern __shared__ double slice[];
  // BASIS_DIM is a JIT-time constant, so only one branch survives.
  switch (BASIS_DIM) {
  case 1: interp1d(nelem, transpose, c_B, d_U, d_V, slice); break;
  case 2: interp2d(nelem, transpose, c_B, d_U, d_V, slice); break;
  case 3: interp3d(nelem, transpose, c_B, d_U, d_V, slice); break;
  }
}
545 
extern "C" __global__ void grad(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *c_G,
                                const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V) {
  // Dynamic shared memory; size is supplied by the host launcher.
  extern __shared__ double slice[];
  // BASIS_DIM is a JIT-time constant, so only one branch survives.
  switch (BASIS_DIM) {
  case 1: grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice); break;
  case 2: grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice); break;
  case 3: grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice); break;
  }
}
559 
560 /////////////
561 // Weights //
562 /////////////
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  // threadIdx.x indexes the 1D quadrature point; (blockIdx.x, threadIdx.y)
  // stride over elements until all nelem are covered.
  const int tid = threadIdx.x;
  const CeedScalar weight = qweight1d[tid];
  const CeedInt stride = gridDim.x*blockDim.y;
  for (CeedInt elem = blockIdx.x*blockDim.y + threadIdx.y; elem < nelem;
       elem += stride)
    w[elem*Q1D + tid] = weight;
}
573 
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  // (threadIdx.x, threadIdx.y) index the 2D quadrature point;
  // (blockIdx.x, threadIdx.z) stride over elements.
  const int i = threadIdx.x;
  const int j = threadIdx.y;
  const CeedScalar weight = qweight1d[i]*qweight1d[j];
  const CeedInt stride = gridDim.x*blockDim.z;
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += stride)
    w[elem*Q1D*Q1D + i + j*Q1D] = weight;
}
585 
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  // (x,y,z) threads cover the Q1D^3 quadrature points of one element;
  // blocks stride over elements.
  const int i = threadIdx.x;
  const int j = threadIdx.y;
  const int k = threadIdx.z;
  const CeedScalar weight = qweight1d[i]*qweight1d[j]*qweight1d[k];
  for (int e = blockIdx.x; e < nelem; e += gridDim.x)
    w[e*Q1D*Q1D*Q1D + i + j*Q1D + k*Q1D*Q1D] = weight;
}
597 
extern "C" __global__ void weight(const CeedInt nelem,
                                  const CeedScalar *__restrict__ qweight1d, CeedScalar *__restrict__ v) {
  // BASIS_DIM is a JIT-time constant, so only one branch survives.
  switch (BASIS_DIM) {
  case 1: weight1d(nelem, qweight1d, v); break;
  case 2: weight2d(nelem, qweight1d, v); break;
  case 3: weight3d(nelem, qweight1d, v); break;
  }
}
608 
609                                    );
610 
611 int CeedCudaInitInterp(CeedScalar *d_B, CeedInt P1d, CeedInt Q1d,
612                        CeedScalar **c_B);
613 int CeedCudaInitInterpGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P1d,
614                            CeedInt Q1d, CeedScalar **c_B_ptr, CeedScalar **c_G_ptr);
615 
616 int CeedBasisApplyTensor_Cuda_shared(CeedBasis basis, const CeedInt nelem,
617                                      CeedTransposeMode tmode,
618                                      CeedEvalMode emode, CeedVector u, CeedVector v) {
619   int ierr;
620   Ceed ceed;
621   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
622   Ceed_Cuda_shared *ceed_Cuda;
623   CeedGetData(ceed, (void *) &ceed_Cuda); CeedChk(ierr);
624   CeedBasis_Cuda_shared *data;
625   CeedBasisGetData(basis, (void *)&data); CeedChk(ierr);
626   const CeedInt transpose = tmode == CEED_TRANSPOSE;
627   CeedInt dim, ncomp;
628   ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
629   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
630 
631   const CeedScalar *d_u;
632   CeedScalar *d_v;
633   if(emode!=CEED_EVAL_WEIGHT) {
634     ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChk(ierr);
635   }
636   ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChk(ierr);
637 
638   if (tmode == CEED_TRANSPOSE) {
639     CeedInt length;
640     ierr = CeedVectorGetLength(v, &length); CeedChk(ierr);
641     ierr = cudaMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChk(ierr);
642   }
643   if (emode == CEED_EVAL_INTERP) {
644     //TODO: check performance difference between c_B and d_B
645     CeedInt P1d, Q1d;
646     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
647     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
648     // if (ceed_Cuda->Q1d != Q1d || ceed_Cuda->P1d != P1d)
649     // {
650     //   ceed_Cuda->Q1d = Q1d;
651     //   ceed_Cuda->P1d = P1d;
652     //   ceed_Cuda->grad = false;
653       ierr = CeedCudaInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B);
654       CeedChk(ierr);
655     // }
656     void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B, &d_u, &d_v};
657     // void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->d_interp1d, &d_u, &d_v};
658     if (dim==1)
659     {
660       CeedInt elemsPerBlock = 32;
661       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
662       CeedInt sharedMem = elemsPerBlock*Q1d*sizeof(CeedScalar);
663       ierr = run_kernel_dim_shared(ceed, data->interp, grid, Q1d, 1, elemsPerBlock, sharedMem,
664                             interpargs);
665       CeedChk(ierr);
666     } else if (dim==2) {
667       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
668       CeedInt elemsPerBlock = Q1d < 7 ? optElems[Q1d]/ncomp : 1;
669       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
670       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
671       ierr = run_kernel_dim_shared(ceed, data->interp, grid, Q1d, Q1d, ncomp*elemsPerBlock, sharedMem,
672                             interpargs);
673       CeedChk(ierr);
674     } else if (dim==3) {
675       CeedInt elemsPerBlock = 2;
676       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
677       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
678       ierr = run_kernel_dim_shared(ceed, data->interp, grid, Q1d, Q1d, ncomp*elemsPerBlock, sharedMem,
679                             interpargs);
680       CeedChk(ierr);
681     }
682   } else if (emode == CEED_EVAL_GRAD) {
683     CeedInt P1d, Q1d;
684     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
685     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
686     // if (ceed_Cuda->Q1d != Q1d || ceed_Cuda->P1d != P1d || !data->grad)
687     // {
688     //   ceed_Cuda->Q1d = Q1d;
689     //   ceed_Cuda->P1d = P1d;
690     //   ceed_Cuda->grad = true;
691       ierr = CeedCudaInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d,
692                                     Q1d, &data->c_B, &data->c_G);
693       CeedChk(ierr);
694     // }
695     void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B, &data->c_G, &d_u, &d_v};
696     // void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->d_interp1d, &data->d_grad1d, &d_u, &d_v};
697     if (dim==1)
698     {
699       CeedInt elemsPerBlock = 32;
700       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
701       CeedInt sharedMem = elemsPerBlock*Q1d*sizeof(CeedScalar);
702       ierr = run_kernel_dim_shared(ceed, data->grad, grid, Q1d, 1, elemsPerBlock, sharedMem,
703                           gradargs);
704       CeedChk(ierr);
705     } else if (dim==2) {
706       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
707       CeedInt elemsPerBlock = Q1d < 7 ? optElems[Q1d]/ncomp : 1;
708       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
709       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
710       ierr = run_kernel_dim_shared(ceed, data->grad, grid, Q1d, Q1d, ncomp*elemsPerBlock, sharedMem,
711                           gradargs);
712       CeedChk(ierr);
713     } else if (dim==3) {
714       CeedInt elemsPerBlock = 2;
715       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
716       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
717       ierr = run_kernel_dim_shared(ceed, data->grad, grid, Q1d, Q1d, ncomp*elemsPerBlock, sharedMem,
718                           gradargs);
719       CeedChk(ierr);
720     }
721   } else if (emode == CEED_EVAL_WEIGHT) {
722     CeedInt Q1d;
723     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
724     void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
725     if(dim==1){
726       const CeedInt elemsPerBlock = 32/Q1d;
727       const CeedInt gridsize = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
728       ierr = run_kernel_dim(ceed, data->weight, gridsize, Q1d, elemsPerBlock, 1, weightargs);
729     } else if(dim==2) {
730       const CeedInt optElems = 32/(Q1d*Q1d);
731       const CeedInt elemsPerBlock = optElems>0?optElems:1;
732       const CeedInt gridsize = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
733       ierr = run_kernel_dim(ceed, data->weight, gridsize, Q1d, Q1d, elemsPerBlock, weightargs);
734     } else if(dim==3) {
735       const CeedInt gridsize = nelem;
736       ierr = run_kernel_dim(ceed, data->weight, gridsize, Q1d, Q1d, Q1d, weightargs);
737     }
738   }
739 
740   if(emode!=CEED_EVAL_WEIGHT) {
741     ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChk(ierr);
742   }
743   ierr = CeedVectorRestoreArray(v, &d_v); CeedChk(ierr);
744 
745   return 0;
746 }
747 
748 static int CeedBasisDestroy_Cuda_shared(CeedBasis basis) {
749   int ierr;
750   Ceed ceed;
751   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
752 
753   CeedBasis_Cuda_shared *data;
754   ierr = CeedBasisGetData(basis, (void *) &data); CeedChk(ierr);
755 
756   CeedChk_Cu(ceed, cuModuleUnload(data->module));
757 
758   ierr = cudaFree(data->d_qweight1d); CeedChk_Cu(ceed, ierr);
759   ierr = cudaFree(data->d_interp1d); CeedChk_Cu(ceed, ierr);
760   ierr = cudaFree(data->d_grad1d); CeedChk_Cu(ceed, ierr);
761 
762   ierr = CeedFree(&data); CeedChk(ierr);
763 
764   return 0;
765 }
766 
767 int CeedBasisCreateTensorH1_Cuda_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
768                                         const CeedScalar *interp1d,
769                                         const CeedScalar *grad1d,
770                                         const CeedScalar *qref1d,
771                                         const CeedScalar *qweight1d,
772                                         CeedBasis basis) {
773   int ierr;
774   Ceed ceed;
775   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
776   CeedBasis_Cuda_shared *data;
777   ierr = CeedCalloc(1, &data); CeedChk(ierr);
778 
779   const CeedInt qBytes = Q1d * sizeof(CeedScalar);
780   ierr = cudaMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Cu(ceed, ierr);
781   ierr = cudaMemcpy(data->d_qweight1d, qweight1d, qBytes,
782                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
783 
784   const CeedInt iBytes = qBytes * P1d;
785   ierr = cudaMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Cu(ceed, ierr);
786   ierr = cudaMemcpy(data->d_interp1d, interp1d, iBytes,
787                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
788 
789   ierr = cudaMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Cu(ceed, ierr);
790   ierr = cudaMemcpy(data->d_grad1d, grad1d, iBytes,
791                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
792 
793   CeedInt ncomp;
794   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
795   ierr = compile(ceed, kernelsShared, &data->module, 7,
796                  "Q1D", Q1d,
797                  "P1D", P1d,
798                  "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
799                      Q1d : P1d, dim),
800                  "BASIS_DIM", dim,
801                  "BASIS_NCOMP", ncomp,
802                  "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
803                  "BASIS_NQPT", CeedIntPow(Q1d, dim)
804                 ); CeedChk(ierr);
805   ierr = get_kernel(ceed, data->module, "interp", &data->interp);
806   CeedChk(ierr);
807   ierr = get_kernel(ceed, data->module, "grad", &data->grad);
808   CeedChk(ierr);
809   ierr = get_kernel(ceed, data->module, "weight", &data->weight);
810   CeedChk(ierr);
811 
812   ierr = CeedBasisSetData(basis, (void *)&data);
813   CeedChk(ierr);
814   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
815                                 CeedBasisApplyTensor_Cuda_shared);
816   CeedChk(ierr);
817   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
818                                 CeedBasisDestroy_Cuda_shared);
819   CeedChk(ierr);
820   return 0;
821 }
822 
823 int CeedBasisCreateH1_Cuda_shared(CeedElemTopology topo, CeedInt dim,
824                                   CeedInt ndof, CeedInt nqpts,
825                                   const CeedScalar *interp,
826                                   const CeedScalar *grad,
827                                   const CeedScalar *qref,
828                                   const CeedScalar *qweight,
829                                   CeedBasis basis) {
830   int ierr;
831   Ceed ceed;
832   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
833   return CeedError(ceed, 1, "Backend does not implement generic H1 basis");
834 }
835