xref: /libCEED/backends/cuda-shared/ceed-cuda-shared-basis.c (revision 7af48cf9dcfc376eb44438498992419314d973a5)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <ceed-backend.h>
18 #include <ceed.h>
19 #include "ceed-cuda-shared.h"
20 #include "../cuda/ceed-cuda.h"
21 
22 //*********************
23 // shared mem kernels
24 // *INDENT-OFF*
25 static const char *kernelsShared = QUOTE(
26 
// Elementwise accumulate: r_V[q] += r_U[q] over the Q1D register entries.
// Used by grad3d to sum per-dimension contributions in the transpose pass.
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  for (int q = 0; q < Q1D; ++q)
    r_V[q] += r_U[q];
}
31 
32 //////////
33 //  1D  //
34 //////////
35 
// Load one element's nodal values for component `comp` into this thread's
// tidz row of shared memory, zero-padding entries in [P1D, Q1D).
// Note: every thread sharing a tidz writes the same full row redundantly.
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int tidz,const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *slice) {
  const int base = comp*P1D + elem*BASIS_NCOMP*P1D;
  for (int i = 0; i < Q1D; ++i)
    slice[i + tidz*Q1D] = (i < P1D) ? d_U[base + i] : 0.0;
}
45 
// Store one nodal value; only the first P1D threads in x hold valid data.
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D) return;
  d_V[tidx + comp*P1D + elem*BASIS_NCOMP*P1D] = r_V;
}
54 
// Load one element's quadrature values for (comp, dim) into this thread's
// tidz row of shared memory. Layout (from the index): dim outermost, then
// comp, then elem, then the 1D quadrature index.
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int tidz, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *slice) {
  const int base = elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  for (int q = 0; q < Q1D; ++q)
    slice[q + tidz*Q1D] = d_U[base + q];
}
63 
// Store this thread's quadrature value for (elem, comp, dim); one value per
// x-thread, tidx being the quadrature point index.
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  const int ind = tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  d_V[ind] = r_V;
}
70 
// V = sum_i B[i + tidx*P1D] * slice[i + tidz*Q1D]: contract the staged node
// values with row tidx of the (Q1D x P1D) basis matrix B.
// The U argument is unused; it is kept for a uniform contraction signature.
// No barrier is needed: readDofs1d makes each thread write every slice entry
// it reads here, so a thread only depends on its own prior stores.
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  V = 0.0;
  for (int i = 0; i < P1D; ++i) {
    V += B[i + tidx*P1D] * slice[i+tidz*Q1D];//contract x direction
  }
}
80 
// Transposed x-contraction: V = sum_i B[tidx + i*P1D] * slice[i + tidz*Q1D].
// The U argument is unused; it is kept for a uniform contraction signature.
// Guard tidx < P1D: B holds P1D*Q1D entries, so threads with tidx >= P1D
// would read past its end (their result is discarded by writeDofs1d anyway).
// The 2D/3D transpose contractions apply the same guard.
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  V = 0.0;
  if (tidx < P1D) {
    for (int i = 0; i < Q1D; ++i) {
      V += B[tidx + i*P1D] * slice[i+tidz*Q1D];//contract x direction
    }
  }
}
89 
// Apply the 1D interpolation basis for every element assigned to this block.
// Expected launch shape (see host code below): blockDim.x == Q1D (one thread
// per quadrature point), blockDim.z == elements per block; `slice` is dynamic
// shared memory of blockDim.z*Q1D scalars.
// transpose == 0 maps nodes -> quadrature values; otherwise the adjoint.
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  // r_t is passed as the (unused) U argument of the contractions and is
  // never initialized or read.
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;


  // Grid-stride over elements: each z-slice of the block owns one element
  // per iteration (stride gridDim.x*blockDim.z).
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp=0; comp<BASIS_NCOMP; comp++) {
      if (!transpose) {
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      } else {
        readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
118 
// Apply the 1D gradient basis (same launch shape as interp1d). In 1D there is
// a single derivative direction, so only c_G is contracted; c_B is unused
// here but kept so grad1d/grad2d/grad3d share a signature.
// transpose == 0 maps nodes -> quadrature derivatives; otherwise the adjoint.
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // r_U is passed as the (unused) U argument of the contractions and is
  // never initialized or read.
  CeedScalar r_U;
  CeedScalar r_V;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  int dim;

  // Grid-stride over elements; each z-slice of the block owns one element.
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp=0; comp<BASIS_NCOMP; comp++) {
      if (!transpose) {
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        dim = 0;
        writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      } else {
        dim = 0;
        readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
149 //////////
150 //  2D  //
151 //////////
152 
// Load one nodal value for (elem, comp) into a register; threads outside the
// P1D x P1D node grid receive an explicit zero pad.
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar &U) {
  if (tidx < P1D && tidy < P1D)
    U = d_U[tidx + tidy*P1D + comp*P1D*P1D + elem*BASIS_NCOMP*P1D*P1D ];
  else
    U = 0.0;
}
162 
// Store one nodal value for (elem, comp); only threads inside the P1D x P1D
// node grid hold valid data.
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D) return;
  d_V[tidx + tidy*P1D + comp*P1D*P1D + elem*BASIS_NCOMP*P1D*P1D ] = r_V;
}
171 
// Load one quadrature value for (elem, comp, dim) into a register. Layout
// (from the index): dim outermost, then comp, elem, qy, qx.
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar &U ) {
  const int ind = tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                  dim*BASIS_NCOMP*nelem*Q1D*Q1D;
  U = d_U[ind];
}
179 
// Store one quadrature value for (elem, comp, dim); mirrors readQuads2d's
// layout (dim outermost, then comp, elem, qy, qx).
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  const int ind = tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                  dim*BASIS_NCOMP*nelem*Q1D*Q1D;
  d_V[ind] = r_V;
}
187 
// Contract along x: V = sum_i B[i + tidx*P1D] * U(i, tidy). U is first staged
// through shared memory so each thread can read its whole x-row. Contains
// __syncthreads(), so all threads of the block must call this together.
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  for (int i = 0; i < P1D; ++i) {
    V += B[i + tidx*P1D] * slice[i + tidy*Q1D + tidz*Q1D*Q1D];//contract x direction
  }
  __syncthreads();
}
200 
// Contract along y: V = sum_i B[i + tidy*P1D] * U(tidx, i), staging U through
// shared memory. Contains __syncthreads(), so all threads of the block must
// call this together.
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  for (int i = 0; i < P1D; ++i) {
    V += B[i + tidy*P1D] * slice[tidx + i*Q1D + tidz*Q1D*Q1D];//contract y direction
  }
  __syncthreads();
}
213 
// Transposed y-contraction: V = sum_i B[tidy + i*P1D] * U(tidx, i). Only the
// tidy < P1D threads compute (B has P1D*Q1D entries); the rest leave V = 0.
// Contains __syncthreads(), so all threads of the block must call this
// together — the guard is around the sum only, not the barriers.
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidy<P1D) {
    for (int i = 0; i < Q1D; ++i) {
      V += B[tidy + i*P1D] * slice[tidx + i*Q1D + tidz*Q1D*Q1D];//contract y direction
    }
  }
  __syncthreads();
}
227 
// Transposed x-contraction: V = sum_i B[tidx + i*P1D] * U(i, tidy). Only the
// tidx < P1D threads compute (B has P1D*Q1D entries); the rest leave V = 0.
// Contains __syncthreads(), so all threads of the block must call this
// together — the guard is around the sum only, not the barriers.
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidx<P1D) {
    for (int i = 0; i < Q1D; ++i) {
      V += B[tidx + i*P1D] * slice[i + tidy*Q1D + tidz*Q1D*Q1D];//contract x direction
    }
  }
  __syncthreads();
}
241 
// Apply the 2D interpolation basis. Expected launch shape (see host code
// below): blockDim.x == blockDim.y == Q1D, blockDim.z == BASIS_NCOMP *
// elemsPerBlock, so each z-slice handles one (element, component) pair;
// `slice` is dynamic shared memory of blockDim.z*Q1D*Q1D scalars.
// transpose == 0 maps nodes -> quadrature values; otherwise the adjoint.
// NOTE(review): when nelem is not a multiple of the per-grid element stride,
// z-slices of the last block can exit this loop after different trip counts
// while the contractions contain __syncthreads(); confirm this divergent
// barrier pattern is safe on all targeted architectures.
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  // comp is loop-invariant; the original redundantly re-declared (shadowed)
  // it inside the element loop.
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    r_V = 0.0;
    r_t = 0.0;
    if (!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
    } else {
      readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
275 
// Apply the 2D gradient basis (same launch shape as interp2d). Forward pass:
// d/dx uses G in x and B in y (dim 0), d/dy uses B in x and G in y (dim 1).
// Transpose pass reads both quadrature derivative fields, applies the
// transposed contractions, and sums the contributions before writing nodes.
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V, CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  // Each z-slice of the block handles one (element, component) pair.
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if (!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // dim 0: derivative in x
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // dim 1: derivative in y
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 1;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      dim = 0;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 1;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      // Sum the x- and y-derivative contributions, then scatter to nodes.
      r_V+=r_U;
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
318 //////////
319 //  3D  //
320 //////////
321 
// Load a z-column of nodal values for (elem, comp) into registers,
// zero-padding both out-of-range threads (tidx/tidy >= P1D) and the register
// tail i in [P1D, Q1D).
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *r_U) {
  const bool valid = (tidx < P1D) && (tidy < P1D);
  for (int i = 0; i < Q1D; i++)
    r_U[i] = (valid && i < P1D)
             ? d_U[tidx + tidy*P1D + i*P1D*P1D + comp*P1D*P1D*P1D +
                   elem*BASIS_NCOMP*P1D*P1D*P1D]
             : 0.0;
}
333 
// Load a z-column of quadrature values for (elem, comp, dim) into registers;
// i walks the z quadrature index with stride Q1D*Q1D.
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *r_U) {
  const int base = tidx + tidy*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
                   dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
  for (int i = 0; i < Q1D; i++)
    r_U[i] = d_U[base + i*Q1D*Q1D];
}
342 
// Store a z-column of nodal values for (elem, comp); only threads inside the
// P1D x P1D node grid hold valid data.
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D) return;
  const int base = tidx + tidy*P1D + comp*P1D*P1D*P1D +
                   elem*BASIS_NCOMP*P1D*P1D*P1D;
  for (int i = 0; i < P1D; i++)
    d_V[base + i*P1D*P1D] = r_V[i];
}
353 
// Store a z-column of quadrature values for (elem, comp, dim); mirrors
// readQuads3d's layout.
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar *r_V, CeedScalar *d_V) {
  const int base = tidx + tidy*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
                   dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
  for (int i = 0; i < Q1D; i++)
    d_V[base + i*Q1D*Q1D] = r_V[i];
}
362 
// Contract along x for each of the P1D z-levels, one shared-memory plane at a
// time: V[k] = sum_i B[i + tidx*P1D] * U_k(i, tidy). Contains __syncthreads()
// inside the k loop, so all threads of the block must call this together.
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    for (int i = 0; i < P1D; ++i) {
      V[k] += B[i + tidx*P1D] * slice[i + tidy*Q1D +
                                      tidz*Q1D*Q1D];//contract x direction
    }
    __syncthreads();
  }
}
378 
// Contract along y for each of the P1D z-levels, one shared-memory plane at a
// time: V[k] = sum_i B[i + tidy*P1D] * U_k(tidx, i). Contains __syncthreads()
// inside the k loop, so all threads of the block must call this together.
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    for (int i = 0; i < P1D; ++i) {
      V[k] += B[i + tidy*P1D] * slice[tidx + i*Q1D +
                                      tidz*Q1D*Q1D];//contract y direction
    }
    __syncthreads();
  }
}
394 
// Contract along z entirely in registers: V[k] = sum_i B[i + k*P1D] * U[i].
// The z-column already lives in each thread's registers, so no shared memory
// or barrier is needed; slice/tidx/tidy/tidz are unused and kept only for a
// uniform contraction signature.
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < Q1D; ++k) {
    V[k] = 0.0;
    for (int i = 0; i < P1D; ++i) {
      V[k] += B[i + k*P1D] * U[i];//contract z direction
    }
  }
}
406 
// Transposed z-contraction in registers: V[k] = sum_i B[k + i*P1D] * U[i] for
// k < P1D; entries k in [P1D, Q1D) are zeroed. No shared memory or barrier;
// slice/tidx/tidy/tidz are unused and kept for a uniform signature.
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < Q1D; ++k) {
    V[k] = 0.0;
    if (k<P1D) {
      for (int i = 0; i < Q1D; ++i) {
        V[k] += B[k + i*P1D] * U[i];//contract z direction
      }
    }
  }
}
419 
// Transposed y-contraction per z-level: V[k] = sum_i B[tidy + i*P1D] *
// U_k(tidx, i). Only tidy < P1D threads compute (B has P1D*Q1D entries); the
// rest leave V[k] = 0. Contains __syncthreads() inside the k loop, so all
// threads of the block must call this together — the guard wraps the sum
// only, not the barriers.
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidy<P1D) {
      for (int i = 0; i < Q1D; ++i) {
        V[k] += B[tidy + i*P1D] * slice[tidx + i*Q1D +
                                        tidz*Q1D*Q1D];//contract y direction
      }
    }
    __syncthreads();
  }
}
436 
// Transposed x-contraction per z-level: V[k] = sum_i B[tidx + i*P1D] *
// U_k(i, tidy). Only tidx < P1D threads compute; the rest leave V[k] = 0.
// Contains __syncthreads() inside the k loop, so all threads of the block
// must call this together — the guard wraps the sum only, not the barriers.
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx<P1D) {
      for (int i = 0; i < Q1D; ++i) {
        V[k] += B[tidx + i*P1D] * slice[i + tidy*Q1D +
                                        tidz*Q1D*Q1D];//contract x direction
      }
    }
    __syncthreads();
  }
}
453 
// Apply the 3D interpolation basis. Expected launch shape (see host code
// below): blockDim.x == blockDim.y == Q1D, blockDim.z == BASIS_NCOMP *
// elemsPerBlock; each z-slice handles one (element, component) pair, each
// thread holds a full z-column in registers, and `slice` is dynamic shared
// memory of blockDim.z*Q1D*Q1D scalars.
// transpose == 0 maps nodes -> quadrature values; otherwise the adjoint.
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V[Q1D];
  CeedScalar r_t[Q1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // Clear the register columns; r_V and r_t alternate as contraction
    // input/output below.
    for (int i = 0; i < Q1D; ++i) {
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
    } else {
      readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
    }
  }
}
490 
// Apply the 3D gradient basis (same launch shape as interp3d). Forward pass
// produces three derivative fields by substituting G for B in exactly one of
// the x/y/z contractions (dim 0/1/2). Transpose pass reads all three fields,
// applies the transposed contractions, and accumulates into r_V via add().
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // NOTE: only P1D entries of one of these arrays are ever valid in the
  // forward pass; sizing it P1D would save registers.
  CeedScalar r_U[Q1D];
  CeedScalar r_V[Q1D];
  CeedScalar r_t[Q1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  // Each z-slice of the block handles one (element, component) pair.
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if (!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // dim 0: G in x, B in y/z
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // dim 1: G in y, B in x/z
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_G, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 1;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // dim 2: G in z, B in x/y
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 2;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      // dim 0 seeds r_V; dims 1 and 2 accumulate into it via add().
      dim = 0;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      dim = 1;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_G, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      dim = 2;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
550 
551 /////////////
552 // Kernels //
553 /////////////
// Kernel entry point for tensor interpolation. BASIS_DIM, Q1D, P1D and
// BASIS_NCOMP are compile-time constants supplied by the backend when this
// source string is compiled, so only one branch survives. Dynamic shared
// memory (`slice`) is sized by the host launch code.
extern "C" __global__ void interp(const CeedInt nelem, const int transpose,
                                  const CeedScalar *c_B,
                                  const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  extern __shared__ double slice[];
  if (BASIS_DIM==1) {
    interp1d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM==2) {
    interp2d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM==3) {
    interp3d(nelem, transpose, c_B, d_U, d_V, slice);
  }
}
567 
// Kernel entry point for tensor gradients; dispatches on the compile-time
// BASIS_DIM constant. Dynamic shared memory (`slice`) is sized by the host
// launch code.
extern "C" __global__ void grad(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *c_G,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V) {
  extern __shared__ double slice[];
  if (BASIS_DIM==1) {
    grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM==2) {
    grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM==3) {
    grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  }
}
581 
582 /////////////
583 // Weights //
584 /////////////
// Broadcast the 1D quadrature weights: thread x owns quadrature point tid,
// thread y indexes the element within the block; grid-stride over elements.
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int tid = threadIdx.x;
  const CeedScalar wq = qweight1d[tid];
  const CeedInt stride = gridDim.x*blockDim.y;
  for (CeedInt e = blockIdx.x*blockDim.y + threadIdx.y; e < nelem; e += stride)
    w[e*Q1D + tid] = wq;
}
595 
// Broadcast the 2D quadrature weights: each (x, y) thread owns one quadrature
// point, thread z indexes the element within the block; grid-stride loop.
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy];
  const CeedInt stride = gridDim.x*blockDim.z;
  for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < nelem; e += stride)
    w[e*Q1D*Q1D + qx + qy*Q1D] = wq;
}
607 
// Broadcast the 3D quadrature weights: one block covers the full Q1D^3 grid
// of a single element; grid-stride over elements.
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const int qz = threadIdx.z;
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy]*qweight1d[qz];
  for (int e = blockIdx.x; e < nelem; e += gridDim.x)
    w[e*Q1D*Q1D*Q1D + qx + qy*Q1D + qz*Q1D*Q1D] = wq;
}
619 
// Kernel entry point for quadrature weights; dispatches on the compile-time
// BASIS_DIM constant. No shared memory is used.
extern "C" __global__ void weight(const CeedInt nelem,
                                  const CeedScalar *__restrict__ qweight1d,
                                  CeedScalar *__restrict__ v) {
  if (BASIS_DIM==1) {
    weight1d(nelem, qweight1d, v);
  } else if (BASIS_DIM==2) {
    weight2d(nelem, qweight1d, v);
  } else if (BASIS_DIM==3) {
    weight3d(nelem, qweight1d, v);
  }
}
631 
632 );
633 // *INDENT-ON*
634 
// Forward declarations, implemented elsewhere in the CUDA backend: stage the
// interpolation (and gradient) matrices on the device, returning device
// pointers through the output arguments. NOTE(review): presumably these
// cache/copy d_B/d_G into constant memory — confirm in the ../cuda sources.
int CeedCudaInitInterp(CeedScalar *d_B, CeedInt P1d, CeedInt Q1d,
                       CeedScalar **c_B);
int CeedCudaInitInterpGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P1d,
                           CeedInt Q1d, CeedScalar **c_B_ptr,
                           CeedScalar **c_G_ptr);
640 
641 int CeedBasisApplyTensor_Cuda_shared(CeedBasis basis, const CeedInt nelem,
642                                      CeedTransposeMode tmode,
643                                      CeedEvalMode emode, CeedVector u,
644                                      CeedVector v) {
645   int ierr;
646   Ceed ceed;
647   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
648   Ceed_Cuda_shared *ceed_Cuda;
649   CeedGetData(ceed, (void *) &ceed_Cuda); CeedChk(ierr);
650   CeedBasis_Cuda_shared *data;
651   CeedBasisGetData(basis, (void *)&data); CeedChk(ierr);
652   const CeedInt transpose = tmode == CEED_TRANSPOSE;
653   CeedInt dim, ncomp;
654   ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
655   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
656 
657   const CeedScalar *d_u;
658   CeedScalar *d_v;
659   if (emode != CEED_EVAL_WEIGHT) {
660     ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChk(ierr);
661   }
662   ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChk(ierr);
663 
664   if (tmode == CEED_TRANSPOSE) {
665     CeedInt length;
666     ierr = CeedVectorGetLength(v, &length); CeedChk(ierr);
667     ierr = cudaMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChk(ierr);
668   }
669   if (emode == CEED_EVAL_INTERP) {
670     CeedInt P1d, Q1d;
671     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
672     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
673     ierr = CeedCudaInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B);
674     CeedChk(ierr);
675     void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
676                           &d_u, &d_v
677                          };
678     if (dim==1) {
679       CeedInt elemsPerBlock = 32;
680       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
681                                              ? 1 : 0 );
682       CeedInt sharedMem = elemsPerBlock*Q1d*sizeof(CeedScalar);
683       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, 1,
684                                         elemsPerBlock, sharedMem,
685                                         interpargs);
686       CeedChk(ierr);
687     } else if (dim==2) {
688       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
689       CeedInt elemsPerBlock = Q1d < 7 ? optElems[Q1d]/ncomp : 1;
690       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
691                                              ? 1 : 0 );
692       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
693       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, Q1d,
694                                         ncomp*elemsPerBlock, sharedMem,
695                                         interpargs);
696       CeedChk(ierr);
697     } else if (dim==3) {
698       CeedInt elemsPerBlock = 1;
699       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
700                                              ? 1 : 0 );
701       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
702       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, Q1d,
703                                         ncomp*elemsPerBlock, sharedMem,
704                                         interpargs);
705       CeedChk(ierr);
706     }
707   } else if (emode == CEED_EVAL_GRAD) {
708     CeedInt P1d, Q1d;
709     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
710     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
711     ierr = CeedCudaInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d,
712                                   Q1d, &data->c_B, &data->c_G);
713     CeedChk(ierr);
714     void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
715                         &data->c_G, &d_u, &d_v
716                        };
717     if (dim==1) {
718       CeedInt elemsPerBlock = 32;
719       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
720                                              ? 1 : 0 );
721       CeedInt sharedMem = elemsPerBlock*Q1d*sizeof(CeedScalar);
722       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, 1, elemsPerBlock,
723                                         sharedMem,
724                                         gradargs);
725       CeedChk(ierr);
726     } else if (dim==2) {
727       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
728       CeedInt elemsPerBlock = Q1d < 7 ? optElems[Q1d]/ncomp : 1;
729       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
730                                              ? 1 : 0 );
731       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
732       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, Q1d,
733                                         ncomp*elemsPerBlock, sharedMem,
734                                         gradargs);
735       CeedChk(ierr);
736     } else if (dim==3) {
737       CeedInt elemsPerBlock = 1;
738       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
739                                              ? 1 : 0 );
740       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
741       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, Q1d,
742                                         ncomp*elemsPerBlock, sharedMem,
743                                         gradargs);
744       CeedChk(ierr);
745     }
746   } else if (emode == CEED_EVAL_WEIGHT) {
747     CeedInt Q1d;
748     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
749     void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
750     if (dim == 1) {
751       const CeedInt elemsPerBlock = 32/Q1d;
752       const CeedInt gridsize = nelem/elemsPerBlock + ( (
753                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
754       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d,
755                                   elemsPerBlock, 1, weightargs);
756       CeedChk(ierr);
757     } else if (dim == 2) {
758       const CeedInt optElems = 32/(Q1d*Q1d);
759       const CeedInt elemsPerBlock = optElems>0?optElems:1;
760       const CeedInt gridsize = nelem/elemsPerBlock + ( (
761                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
762       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d,
763                                   elemsPerBlock, weightargs);
764       CeedChk(ierr);
765     } else if (dim == 3) {
766       const CeedInt gridsize = nelem;
767       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d, Q1d,
768                                   weightargs);
769       CeedChk(ierr);
770     }
771   }
772 
773   if (emode != CEED_EVAL_WEIGHT) {
774     ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChk(ierr);
775   }
776   ierr = CeedVectorRestoreArray(v, &d_v); CeedChk(ierr);
777 
778   return 0;
779 }
780 
781 static int CeedBasisDestroy_Cuda_shared(CeedBasis basis) {
782   int ierr;
783   Ceed ceed;
784   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
785 
786   CeedBasis_Cuda_shared *data;
787   ierr = CeedBasisGetData(basis, (void *) &data); CeedChk(ierr);
788 
789   CeedChk_Cu(ceed, cuModuleUnload(data->module));
790 
791   ierr = cudaFree(data->d_qweight1d); CeedChk_Cu(ceed, ierr);
792   ierr = cudaFree(data->d_interp1d); CeedChk_Cu(ceed, ierr);
793   ierr = cudaFree(data->d_grad1d); CeedChk_Cu(ceed, ierr);
794 
795   ierr = CeedFree(&data); CeedChk(ierr);
796 
797   return 0;
798 }
799 
800 int CeedBasisCreateTensorH1_Cuda_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
801                                         const CeedScalar *interp1d,
802                                         const CeedScalar *grad1d,
803                                         const CeedScalar *qref1d,
804                                         const CeedScalar *qweight1d,
805                                         CeedBasis basis) {
806   int ierr;
807   Ceed ceed;
808   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
809   if (Q1d<P1d) {
810     return CeedError(ceed, 1, "Backend does not implement underintegrated basis.");
811   }
812   CeedBasis_Cuda_shared *data;
813   ierr = CeedCalloc(1, &data); CeedChk(ierr);
814 
815   const CeedInt qBytes = Q1d * sizeof(CeedScalar);
816   ierr = cudaMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Cu(ceed, ierr);
817   ierr = cudaMemcpy(data->d_qweight1d, qweight1d, qBytes,
818                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
819 
820   const CeedInt iBytes = qBytes * P1d;
821   ierr = cudaMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Cu(ceed, ierr);
822   ierr = cudaMemcpy(data->d_interp1d, interp1d, iBytes,
823                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
824 
825   ierr = cudaMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Cu(ceed, ierr);
826   ierr = cudaMemcpy(data->d_grad1d, grad1d, iBytes,
827                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
828 
829   data->d_collograd1d = NULL;
830   if (dim==3 && Q1d >= P1d) {
831     CeedScalar *collograd1d;
832     ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChk(ierr);
833     ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChk(ierr);
834     ierr = cudaMalloc((void **)&data->d_collograd1d, qBytes * Q1d);
835     CeedChk_Cu(ceed, ierr);
836     ierr = cudaMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d,
837                       cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
838   }
839 
840   CeedInt ncomp;
841   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
842   ierr = CeedCompileCuda(ceed, kernelsShared, &data->module, 7,
843                          "Q1D", Q1d,
844                          "P1D", P1d,
845                          "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
846                              Q1d : P1d, dim),
847                          "BASIS_DIM", dim,
848                          "BASIS_NCOMP", ncomp,
849                          "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
850                          "BASIS_NQPT", CeedIntPow(Q1d, dim)
851                         ); CeedChk(ierr);
852   ierr = CeedGetKernelCuda(ceed, data->module, "interp", &data->interp);
853   CeedChk(ierr);
854   ierr = CeedGetKernelCuda(ceed, data->module, "grad", &data->grad);
855   CeedChk(ierr);
856   ierr = CeedGetKernelCuda(ceed, data->module, "weight", &data->weight);
857   CeedChk(ierr);
858 
859   ierr = CeedBasisSetData(basis, (void *)&data);
860   CeedChk(ierr);
861   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
862                                 CeedBasisApplyTensor_Cuda_shared);
863   CeedChk(ierr);
864   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
865                                 CeedBasisDestroy_Cuda_shared);
866   CeedChk(ierr);
867   return 0;
868 }
869