xref: /libCEED/backends/cuda-shared/ceed-cuda-shared-basis.c (revision ab85ce399a53bd0f1efd1c879a2c758f69243c90)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <ceed/ceed.h>
18 #include <ceed/backend.h>
19 #include <cuda.h>
20 #include <cuda_runtime.h>
21 #include <stddef.h>
22 #include "ceed-cuda-shared.h"
23 #include "../cuda/ceed-cuda.h"
24 
25 //------------------------------------------------------------------------------
26 // Shared mem kernels
27 //------------------------------------------------------------------------------
28 // *INDENT-OFF*
29 static const char *kernelsShared = QUOTE(
30 
31 //------------------------------------------------------------------------------
32 // Sum input into output
33 //------------------------------------------------------------------------------
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  // Elementwise accumulate the P1D entries of r_U into r_V
  for (int j = 0; j < P1D; j++) {
    r_V[j] += r_U[j];
  }
}
38 
39 //------------------------------------------------------------------------------
40 // 1D
41 //------------------------------------------------------------------------------
42 
43 //------------------------------------------------------------------------------
44 // Read DoFs
45 //------------------------------------------------------------------------------
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int tidz,const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *slice) {
  // Copy this element/component's P1D nodal values into the thread's
  // shared-memory slice row (row selected by tidz)
  const int base = elem*P1D + comp*P1D*nelem;
  for (int n = 0; n < P1D; n++)
    slice[n + tidz*T1D] = d_U[base + n];
  // Zero-pad when Q1D > P1D so later contractions read zeros
  for (int n = P1D; n < Q1D; n++)
    slice[n + tidz*T1D] = 0.0;
}
55 
56 //------------------------------------------------------------------------------
57 // Write DoFs
58 //------------------------------------------------------------------------------
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  // Only threads that map to a valid node store a value
  if (tidx < P1D) {
    const int ind = tidx + elem*P1D + comp*P1D*nelem;
    d_V[ind] = r_V;
  }
}
66 
67 //------------------------------------------------------------------------------
68 // Read quadrature point data
69 //------------------------------------------------------------------------------
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int tidz, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *slice) {
  // Copy this element/component/dim's Q1D quadrature values into the
  // thread's shared-memory slice row (row selected by tidz)
  const int base = elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  for (int n = 0; n < Q1D; n++)
    slice[n + tidz*T1D] = d_U[base + n];
  // Zero-pad when P1D > Q1D so later contractions read zeros
  for (int n = Q1D; n < P1D; n++)
    slice[n + tidz*T1D] = 0.0;
}
80 
81 //------------------------------------------------------------------------------
82 // Write quadrature point data
83 //------------------------------------------------------------------------------
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  // Only threads that map to a valid quadrature point store a value
  if (tidx < Q1D) {
    const int ind = tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
    d_V[ind] = r_V;
  }
}
91 
92 //------------------------------------------------------------------------------
93 // 1D tensor contraction
94 //------------------------------------------------------------------------------
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  // Contract x direction: V = sum_i B[i, tidx] * slice[i].
  // U is unused here; the input data is staged in the shared slice by the
  // caller. Signature kept for parity with the 2D/3D contractions.
  CeedScalar sum = 0.0;
  for (int i = 0; i < P1D; ++i)
    sum += B[i + tidx*P1D] * slice[i + tidz*T1D];
  V = sum;
}
103 
104 //------------------------------------------------------------------------------
105 // 1D transpose tensor contraction
106 //------------------------------------------------------------------------------
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  // Transpose contract x direction: V = sum_i B[tidx, i] * slice[i].
  // U is unused here; the input data is staged in the shared slice by the
  // caller. Signature kept for parity with the 2D/3D contractions.
  CeedScalar sum = 0.0;
  for (int i = 0; i < Q1D; ++i)
    sum += B[tidx + i*P1D] * slice[i + tidz*T1D];
  V = sum;
}
114 
115 //------------------------------------------------------------------------------
116 // 1D interpolate to quadrature points
117 //------------------------------------------------------------------------------
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  // 1D interpolation for all components of all elements:
  //   transpose == 0 : apply B   (dofs -> quadrature values)
  //   transpose != 0 : apply B^T (quadrature values -> dofs)
  // slice is block-shared scratch; each z-layer of the block owns one row.
  CeedScalar r_V;
  CeedScalar r_t;  // NOTE(review): passed into the contractions but never read by them

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;

  // Grid-stride loop over elements; each z-layer of the block handles one element
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for (int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) {
        // Stage dofs in shared memory, contract with B, store at quad points
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      } else {
        // Stage quad values in shared memory, contract with B^T, store dofs
        readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
145 
146 //------------------------------------------------------------------------------
147 // 1D derivatives at quadrature points
148 //------------------------------------------------------------------------------
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // 1D gradient for all components of all elements:
  //   transpose == 0 : apply derivative matrix G   (dofs -> quad derivatives)
  //   transpose != 0 : apply G^T (quad derivatives -> dofs)
  // c_B is accepted for signature parity with the 2D/3D variants but is
  // unused here — in 1D there is only the derivative direction.
  CeedScalar r_U;  // NOTE(review): passed into the contractions but never read by them
  CeedScalar r_V;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  int dim;  // always 0 in 1D; kept explicit for symmetry with higher dims

  // Grid-stride loop over elements; each z-layer of the block handles one element
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) {
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        dim = 0;
        writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      } else {
        dim = 0;
        readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
179 
180 //------------------------------------------------------------------------------
181 // 1D Quadrature weights
182 //------------------------------------------------------------------------------
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  // Broadcast the 1D quadrature weight owned by this thread to every element.
  // Assumes blockDim.x covers the quadrature points (tid < Q1D) — confirm at launch.
  const int qpt = threadIdx.x;
  const CeedScalar wq = qweight1d[qpt];
  const CeedInt stride = gridDim.x*blockDim.y;
  for (CeedInt e = blockIdx.x*blockDim.y + threadIdx.y; e < nelem; e += stride) {
    w[e*Q1D + qpt] = wq;
  }
}
193 
194 //------------------------------------------------------------------------------
195 // 2D
196 //------------------------------------------------------------------------------
197 
198 //------------------------------------------------------------------------------
199 // Read DoFs
200 //------------------------------------------------------------------------------
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar &U) {
  // Load one nodal value per thread; threads outside the P1D x P1D node
  // grid receive zero so later contractions stay well defined
  if (tidx < P1D && tidy < P1D) {
    U = d_U[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem];
  } else {
    U = 0.0;
  }
}
208 
209 //------------------------------------------------------------------------------
210 // Write DoFs
211 //------------------------------------------------------------------------------
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  // Only threads inside the P1D x P1D node grid store a value
  if (tidx < P1D && tidy < P1D) {
    const int ind = tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem;
    d_V[ind] = r_V;
  }
}
219 
220 //------------------------------------------------------------------------------
221 // Read quadrature point data
222 //------------------------------------------------------------------------------
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar &U ) {
  // Load one quadrature value per thread; threads outside the Q1D x Q1D
  // grid receive zero so later contractions stay well defined
  if (tidx < Q1D && tidy < Q1D) {
    U = d_U[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
            dim*BASIS_NCOMP*nelem*Q1D*Q1D];
  } else {
    U = 0.0;
  }
}
231 
232 //------------------------------------------------------------------------------
233 // Write quadrature point data
234 //------------------------------------------------------------------------------
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  // Only threads inside the Q1D x Q1D quadrature grid store a value
  if (tidx < Q1D && tidy < Q1D) {
    const int ind = tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                    dim*BASIS_NCOMP*nelem*Q1D*Q1D;
    d_V[ind] = r_V;
  }
}
243 
244 //------------------------------------------------------------------------------
245 // 2D tensor contraction x
246 //------------------------------------------------------------------------------
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  // Stage this thread's value in shared memory so the whole x-row is visible
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();  // all writes must land before any thread reads the row
  V = 0.0;
  // Only threads mapping to an x quadrature point produce output
  if (tidx < Q1D)
    for (int i = 0; i < P1D; ++i)
      V += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
  __syncthreads();  // protect the slice before the caller reuses it
}
259 
260 //------------------------------------------------------------------------------
261 // 2D tensor contraction y
262 //------------------------------------------------------------------------------
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  // Stage this thread's value in shared memory so the whole y-column is visible
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();  // all writes must land before any thread reads the column
  V = 0.0;
  // Only threads mapping to a y quadrature point produce output
  if (tidy < Q1D)
    for (int i = 0; i < P1D; ++i)
      V += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
  __syncthreads();  // protect the slice before the caller reuses it
}
275 
276 //------------------------------------------------------------------------------
277 // 2D transpose tensor contraction y
278 //------------------------------------------------------------------------------
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  // Stage this thread's value in shared memory so the whole y-column is visible
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();  // all writes must land before any thread reads the column
  V = 0.0;
  // Transpose contraction: only threads mapping to a y node produce output
  if (tidy < P1D)
    for (int i = 0; i < Q1D; ++i)
      V += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
  __syncthreads();  // protect the slice before the caller reuses it
}
290 
291 //------------------------------------------------------------------------------
292 // 2D transpose tensor contraction x
293 //------------------------------------------------------------------------------
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  // Stage this thread's value in shared memory so the whole x-row is visible
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();  // all writes must land before any thread reads the row
  V = 0.0;
  // Transpose contraction: only threads mapping to an x node produce output
  if (tidx < P1D)
    for (int i = 0; i < Q1D; ++i)
      V += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
  __syncthreads();  // protect the slice before the caller reuses it
}
305 
306 //------------------------------------------------------------------------------
307 // 2D interpolate to quadrature points
308 //------------------------------------------------------------------------------
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  // 2D tensor-product interpolation for all components of all elements:
  //   transpose == 0 : apply B in x then y (dofs -> quadrature values)
  //   transpose != 0 : apply B^T in y then x (quadrature values -> dofs)
  // slice is block-shared scratch of at least blockDim.z*T1D*T1D scalars.
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  // The z dimension of the block interleaves components and elements
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // Fix: removed a redundant redeclaration of comp here that shadowed the
    // identical function-scope value computed above
    r_V = 0.0;
    r_t = 0.0;
    if (!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
    } else {
      readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
342 
343 //------------------------------------------------------------------------------
344 // 2D derivatives at quadrature points
345 //------------------------------------------------------------------------------
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V, CeedScalar *slice) {
  // 2D gradient via tensor contractions:
  //   forward  : dim 0 gets (G in x, B in y) U; dim 1 gets (B in x, G in y) U
  //   transpose: contributions from both dim slots are summed into the dofs
  CeedScalar r_U;
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  // The z dimension of the block interleaves components and elements
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if (!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // d/dx: derivative in x, interpolation in y
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // d/dy: interpolation in x, derivative in y
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 1;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      // Transpose: apply (B^T, G^T) to dim 0 and (G^T, B^T) to dim 1, sum
      dim = 0;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 1;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      r_V += r_U;  // accumulate the two directional contributions
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
388 
389 //------------------------------------------------------------------------------
390 // 2D quadrature weights
391 //------------------------------------------------------------------------------
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  // Tensor-product quadrature weight at (qx, qy), broadcast to all elements.
  // Assumes blockDim.x and blockDim.y cover Q1D points — confirm at launch.
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy];
  const CeedInt stride = gridDim.x*blockDim.z;
  for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < nelem; e += stride) {
    w[e*Q1D*Q1D + qx + qy*Q1D] = wq;
  }
}
403 
404 //------------------------------------------------------------------------------
405 // 3D
406 //------------------------------------------------------------------------------
407 
408 //------------------------------------------------------------------------------
409 // Read DoFs
410 //------------------------------------------------------------------------------
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *r_U) {
  // Each active thread loads its (tidx, tidy) column of P1D nodal values
  // into registers; inactive threads receive zeros
  const bool active = (tidx < P1D && tidy < P1D);
  const int base = tidx + tidy*P1D + elem*P1D*P1D*P1D + comp*P1D*P1D*P1D*nelem;
  for (int k = 0; k < P1D; k++)
    r_U[k] = active ? d_U[base + k*P1D*P1D] : 0.0;
  // Zero-pad when Q1D > P1D so later contractions read zeros
  for (int k = P1D; k < Q1D; k++)
    r_U[k] = 0.0;
}
422 
423 //------------------------------------------------------------------------------
424 // Write DoFs
425 //------------------------------------------------------------------------------
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V,
                                   CeedScalar *d_V) {
  // Each active thread stores its (tidx, tidy) column of P1D nodal values
  if (tidx < P1D && tidy < P1D) {
    const int base = tidx + tidy*P1D + elem*P1D*P1D*P1D + comp*P1D*P1D*P1D*nelem;
    for (int k = 0; k < P1D; k++)
      d_V[base + k*P1D*P1D] = r_V[k];
  }
}
436 
437 //------------------------------------------------------------------------------
438 // Read quadrature point data
439 //------------------------------------------------------------------------------
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *r_U) {
  // Each active thread loads its (tidx, tidy) column of Q1D quadrature
  // values into registers; inactive threads receive zeros
  const bool active = (tidx < Q1D && tidy < Q1D);
  const int base = tidx + tidy*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
                   dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
  for (int k = 0; k < Q1D; k++)
    r_U[k] = active ? d_U[base + k*Q1D*Q1D] : 0.0;
  // Zero-pad when P1D > Q1D so later contractions read zeros
  for (int k = Q1D; k < P1D; k++)
    r_U[k] = 0.0;
}
451 
452 //------------------------------------------------------------------------------
453 // Write quadrature point data
454 //------------------------------------------------------------------------------
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar *r_V, CeedScalar *d_V) {
  // Each active thread stores its (tidx, tidy) column of Q1D quadrature values
  if (tidx < Q1D && tidy < Q1D) {
    const int base = tidx + tidy*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
                     dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
    for (int k = 0; k < Q1D; k++)
      d_V[base + k*Q1D*Q1D] = r_V[k];
  }
}
465 
466 //------------------------------------------------------------------------------
467 // 3D tensor contract x
468 //------------------------------------------------------------------------------
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  // Contract each of the P1D z-layers along x through shared memory
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();  // layer fully staged before any thread reads it
    V[k] = 0.0;
    if (tidx < Q1D && tidy < P1D)
      for (int i = 0; i < P1D; ++i)
        V[k] += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
    __syncthreads();  // layer fully consumed before the next iteration overwrites it
  }
}
484 
485 //------------------------------------------------------------------------------
486 // 3D tensor contract y
487 //------------------------------------------------------------------------------
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  // Contract each of the P1D z-layers along y through shared memory
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();  // layer fully staged before any thread reads it
    V[k] = 0.0;
    if (tidx < Q1D && tidy < Q1D)
      for (int i = 0; i < P1D; ++i)
        V[k] += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
    __syncthreads();  // layer fully consumed before the next iteration overwrites it
  }
}
503 
504 //------------------------------------------------------------------------------
505 // 3D tensor contract z
506 //------------------------------------------------------------------------------
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  // Contract along z entirely in registers: each thread holds its own full
  // z-column in U, so no shared-memory staging (slice/tidz unused here)
  for (int k = 0; k < Q1D; ++k) {
    V[k] = 0.0;
    if (tidx < Q1D && tidy < Q1D)
      for (int i = 0; i < P1D; ++i)
        V[k] += B[i + k*P1D] * U[i]; // Contract z direction
  }
  // Zero-pad when P1D > Q1D
  for (int k = Q1D; k < P1D; ++k)
    V[k] = 0.0;
}
521 
522 //------------------------------------------------------------------------------
523 // 3D transpose tensor contract z
524 //------------------------------------------------------------------------------
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  // Transpose contract along z entirely in registers: each thread holds its
  // own full z-column in U, so no shared-memory staging (slice/tidz unused)
  for (int k = 0; k < P1D; ++k) {
    V[k] = 0.0;
    if (tidx < Q1D && tidy < Q1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[k + i*P1D] * U[i]; // Contract z direction
  }
  // Zero-pad when Q1D > P1D
  for (int k = P1D; k < Q1D; ++k)
    V[k] = 0.0;
}
539 
540 //------------------------------------------------------------------------------
541 // 3D transpose tensor contract y
542 //------------------------------------------------------------------------------
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  // Transpose contract each of the P1D z-layers along y through shared memory
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();  // layer fully staged before any thread reads it
    V[k] = 0.0;
    if (tidx < Q1D && tidy < P1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
    __syncthreads();  // layer fully consumed before the next iteration overwrites it
  }
}
558 
559 //------------------------------------------------------------------------------
560 // 3D transpose tensor contract x
561 //------------------------------------------------------------------------------
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  // Transpose contract each of the P1D z-layers along x through shared memory
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();  // layer fully staged before any thread reads it
    V[k] = 0.0;
    if (tidx < P1D && tidy < P1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
    __syncthreads();  // layer fully consumed before the next iteration overwrites it
  }
}
577 
578 //------------------------------------------------------------------------------
579 // 3D interpolate to quadrature points
580 //------------------------------------------------------------------------------
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  // 3D tensor-product interpolation for all components of all elements:
  //   transpose == 0 : apply B in x, y, z (dofs -> quadrature values)
  //   transpose != 0 : apply B^T in z, y, x (quadrature values -> dofs)
  // Each thread keeps its z-column in registers; slice is block-shared scratch.
  CeedScalar r_V[T1D];
  CeedScalar r_t[T1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  // The z dimension of the block interleaves components and elements
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // Clear register columns before the contraction chain
    for (int i = 0; i < T1D; ++i) {
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
    } else {
      readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
    }
  }
}
617 
618 //------------------------------------------------------------------------------
619 // 3D derivatives at quadrature points
620 //------------------------------------------------------------------------------
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // 3D gradient via tensor contractions. Forward: dim d applies G in
  // direction d and B in the others; transpose: the three dim slots are
  // contracted with the transposes and summed into the dofs.
  // Use P1D for one of these
  CeedScalar r_U[T1D];
  CeedScalar r_V[T1D];
  CeedScalar r_t[T1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  // The z dimension of the block interleaves components and elements
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // Clear register columns before the contraction chains
    for (int i = 0; i < T1D; ++i) {
      r_U[i] = 0.0;
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // d/dx: G in x, B in y and z
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // d/dy: G in y, B in x and z
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_G, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 1;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // d/dz: G in z, B in x and y
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 2;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      // Transpose: accumulate all three directional contributions into r_V
      dim = 0;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      dim = 1;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_G, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      dim = 2;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
685 
686 //------------------------------------------------------------------------------
687 // 3D quadrature weights
688 //------------------------------------------------------------------------------
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  // One thread per quadrature point; the host launches this with a
  // Q1d x Q1d x Q1d thread block (see the dim == 3 weight launch)
  const int i = threadIdx.x;
  const int j = threadIdx.y;
  const int k = threadIdx.z;
  // Tensor-product weight for this (i, j, k) quadrature point
  const CeedScalar weight = qweight1d[i]*qweight1d[j]*qweight1d[k];
  // Grid-stride over elements; every element receives the same weight pattern
  for (int e = blockIdx.x; e < nelem; e += gridDim.x) {
    // x fastest, then y, then z within each element's Q1D^3 slab
    const int ind = e*Q1D*Q1D*Q1D + i + j*Q1D + k*Q1D*Q1D;
    w[ind] = weight;
  }
}
700 
701 //------------------------------------------------------------------------------
702 // Basis kernels
703 //------------------------------------------------------------------------------
704 
705 //------------------------------------------------------------------------------
706 // Interp kernel by dim
707 //------------------------------------------------------------------------------
708 extern "C" __global__ void interp(const CeedInt nelem, const int transpose,
709                                   const CeedScalar *c_B,
710                                   const CeedScalar *__restrict__ d_U,
711                                   CeedScalar *__restrict__ d_V) {
712   extern __shared__ CeedScalar slice[];
713   if (BASIS_DIM == 1) {
714     interp1d(nelem, transpose, c_B, d_U, d_V, slice);
715   } else if (BASIS_DIM == 2) {
716     interp2d(nelem, transpose, c_B, d_U, d_V, slice);
717   } else if (BASIS_DIM == 3) {
718     interp3d(nelem, transpose, c_B, d_U, d_V, slice);
719   }
720 }
721 
722 //------------------------------------------------------------------------------
723 // Grad kernel by dim
724 //------------------------------------------------------------------------------
725 extern "C" __global__ void grad(const CeedInt nelem, const int transpose,
726                                 const CeedScalar *c_B, const CeedScalar *c_G,
727                                 const CeedScalar *__restrict__ d_U,
728                                 CeedScalar *__restrict__ d_V) {
729   extern __shared__ CeedScalar slice[];
730   if (BASIS_DIM == 1) {
731     grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
732   } else if (BASIS_DIM == 2) {
733     grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
734   } else if (BASIS_DIM == 3) {
735     grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
736   }
737 }
738 
739 //------------------------------------------------------------------------------
740 // Weight kernels by dim
741 //------------------------------------------------------------------------------
742 extern "C" __global__ void weight(const CeedInt nelem,
743                                   const CeedScalar *__restrict__ qweight1d,
744                                   CeedScalar *__restrict__ v) {
745   if (BASIS_DIM == 1) {
746     weight1d(nelem, qweight1d, v);
747   } else if (BASIS_DIM == 2) {
748     weight2d(nelem, qweight1d, v);
749   } else if (BASIS_DIM == 3) {
750     weight3d(nelem, qweight1d, v);
751   }
752 }
753 
754 );
755 // *INDENT-ON*
756 
757 //------------------------------------------------------------------------------
// Device initialization
759 //------------------------------------------------------------------------------
760 int CeedCudaInitInterp(CeedScalar *d_B, CeedInt P1d, CeedInt Q1d,
761                        CeedScalar **c_B);
762 int CeedCudaInitInterpGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P1d,
763                            CeedInt Q1d, CeedScalar **c_B_ptr,
764                            CeedScalar **c_G_ptr);
765 
766 //------------------------------------------------------------------------------
767 // Apply basis
768 //------------------------------------------------------------------------------
769 int CeedBasisApplyTensor_Cuda_shared(CeedBasis basis, const CeedInt nelem,
770                                      CeedTransposeMode tmode,
771                                      CeedEvalMode emode, CeedVector u,
772                                      CeedVector v) {
773   int ierr;
774   Ceed ceed;
775   ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
776   Ceed_Cuda *ceed_Cuda;
777   CeedGetData(ceed, &ceed_Cuda); CeedChkBackend(ierr);
778   CeedBasis_Cuda_shared *data;
779   CeedBasisGetData(basis, &data); CeedChkBackend(ierr);
780   const CeedInt transpose = tmode == CEED_TRANSPOSE;
781   CeedInt dim, ncomp;
782   ierr = CeedBasisGetDimension(basis, &dim); CeedChkBackend(ierr);
783   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChkBackend(ierr);
784 
785   // Read vectors
786   const CeedScalar *d_u;
787   CeedScalar *d_v;
788   if (emode != CEED_EVAL_WEIGHT) {
789     ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChkBackend(ierr);
790   }
791   ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChkBackend(ierr);
792 
793   // Clear v for transpose mode
794   if (tmode == CEED_TRANSPOSE) {
795     CeedInt length;
796     ierr = CeedVectorGetLength(v, &length); CeedChkBackend(ierr);
797     ierr = cudaMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChkBackend(ierr);
798   }
799 
800   // Apply basis operation
801   switch (emode) {
802   case CEED_EVAL_INTERP: {
803     CeedInt P1d, Q1d;
804     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChkBackend(ierr);
805     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr);
806     CeedInt thread1d = CeedIntMax(Q1d, P1d);
807     ierr = CeedCudaInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B);
808     CeedChkBackend(ierr);
809     void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
810                           &d_u, &d_v
811                          };
812     if (dim == 1) {
813       CeedInt elemsPerBlock = 32;
814       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
815                                              ? 1 : 0 );
816       CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
817       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, thread1d, 1,
818                                         elemsPerBlock, sharedMem,
819                                         interpargs); CeedChkBackend(ierr);
820     } else if (dim == 2) {
821       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
822       // elemsPerBlock must be at least 1
823       CeedInt elemsPerBlock = CeedIntMax(thread1d<7?optElems[thread1d]/ncomp:1, 1);
824       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
825                                              ? 1 : 0 );
826       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
827       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, thread1d, thread1d,
828                                         ncomp*elemsPerBlock, sharedMem,
829                                         interpargs); CeedChkBackend(ierr);
830     } else if (dim == 3) {
831       CeedInt elemsPerBlock = 1;
832       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
833                                              ? 1 : 0 );
834       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
835       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, thread1d, thread1d,
836                                         ncomp*elemsPerBlock, sharedMem,
837                                         interpargs); CeedChkBackend(ierr);
838     }
839   } break;
840   case CEED_EVAL_GRAD: {
841     CeedInt P1d, Q1d;
842     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChkBackend(ierr);
843     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr);
844     CeedInt thread1d = CeedIntMax(Q1d, P1d);
845     ierr = CeedCudaInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d,
846                                   Q1d, &data->c_B, &data->c_G);
847     CeedChkBackend(ierr);
848     void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
849                         &data->c_G, &d_u, &d_v
850                        };
851     if (dim == 1) {
852       CeedInt elemsPerBlock = 32;
853       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
854                                              ? 1 : 0 );
855       CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
856       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, thread1d, 1,
857                                         elemsPerBlock, sharedMem, gradargs);
858       CeedChkBackend(ierr);
859     } else if (dim == 2) {
860       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
861       // elemsPerBlock must be at least 1
862       CeedInt elemsPerBlock = CeedIntMax(thread1d<7?optElems[thread1d]/ncomp:1, 1);
863       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
864                                              ? 1 : 0 );
865       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
866       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, thread1d, thread1d,
867                                         ncomp*elemsPerBlock, sharedMem,
868                                         gradargs); CeedChkBackend(ierr);
869     } else if (dim == 3) {
870       CeedInt elemsPerBlock = 1;
871       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
872                                              ? 1 : 0 );
873       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
874       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, thread1d, thread1d,
875                                         ncomp*elemsPerBlock, sharedMem,
876                                         gradargs); CeedChkBackend(ierr);
877     }
878   } break;
879   case CEED_EVAL_WEIGHT: {
880     CeedInt Q1d;
881     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr);
882     void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
883     if (dim == 1) {
884       const CeedInt elemsPerBlock = 32/Q1d;
885       const CeedInt gridsize = nelem/elemsPerBlock + ( (
886                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
887       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d,
888                                   elemsPerBlock, 1, weightargs);
889       CeedChkBackend(ierr);
890     } else if (dim == 2) {
891       const CeedInt optElems = 32/(Q1d*Q1d);
892       const CeedInt elemsPerBlock = optElems>0?optElems:1;
893       const CeedInt gridsize = nelem/elemsPerBlock + ( (
894                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
895       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d,
896                                   elemsPerBlock, weightargs);
897       CeedChkBackend(ierr);
898     } else if (dim == 3) {
899       const CeedInt gridsize = nelem;
900       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d, Q1d,
901                                   weightargs);
902       CeedChkBackend(ierr);
903     }
904   } break;
905   // LCOV_EXCL_START
906   // Evaluate the divergence to/from the quadrature points
907   case CEED_EVAL_DIV:
908     return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported");
909   // Evaluate the curl to/from the quadrature points
910   case CEED_EVAL_CURL:
911     return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported");
912   // Take no action, BasisApply should not have been called
913   case CEED_EVAL_NONE:
914     return CeedError(ceed, CEED_ERROR_BACKEND,
915                      "CEED_EVAL_NONE does not make sense in this context");
916     // LCOV_EXCL_STOP
917   }
918 
919   // Restore vectors
920   if (emode != CEED_EVAL_WEIGHT) {
921     ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChkBackend(ierr);
922   }
923   ierr = CeedVectorRestoreArray(v, &d_v); CeedChkBackend(ierr);
924   return CEED_ERROR_SUCCESS;
925 }
926 
927 //------------------------------------------------------------------------------
928 // Destroy basis
929 //------------------------------------------------------------------------------
930 static int CeedBasisDestroy_Cuda_shared(CeedBasis basis) {
931   int ierr;
932   Ceed ceed;
933   ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
934 
935   CeedBasis_Cuda_shared *data;
936   ierr = CeedBasisGetData(basis, &data); CeedChkBackend(ierr);
937 
938   CeedChk_Cu(ceed, cuModuleUnload(data->module));
939 
940   ierr = cudaFree(data->d_qweight1d); CeedChk_Cu(ceed, ierr);
941   ierr = cudaFree(data->d_interp1d); CeedChk_Cu(ceed, ierr);
942   ierr = cudaFree(data->d_grad1d); CeedChk_Cu(ceed, ierr);
943   ierr = cudaFree(data->d_collograd1d); CeedChk_Cu(ceed, ierr);
944 
945   ierr = CeedFree(&data); CeedChkBackend(ierr);
946 
947   return CEED_ERROR_SUCCESS;
948 }
949 
950 //------------------------------------------------------------------------------
951 // Create tensor basis
952 //------------------------------------------------------------------------------
int CeedBasisCreateTensorH1_Cuda_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
                                        const CeedScalar *interp1d,
                                        const CeedScalar *grad1d,
                                        const CeedScalar *qref1d,
                                        const CeedScalar *qweight1d,
                                        CeedBasis basis) {
  // Allocates device copies of the 1D basis operators, JIT-compiles the
  // shared-memory basis kernels specialized to (dim, P1d, Q1d, ncomp), and
  // registers the Apply/Destroy backend functions on the basis object.
  // Note: qref1d is accepted for interface compatibility but is not copied
  // to the device here.
  int ierr;
  Ceed ceed;
  ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
  CeedBasis_Cuda_shared *data;
  // Zero-initialized, so unset pointers (e.g. d_collograd1d) start as NULL
  ierr = CeedCalloc(1, &data); CeedChkBackend(ierr);

  // Copy basis data to GPU
  const CeedInt qBytes = Q1d * sizeof(CeedScalar);
  ierr = cudaMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Cu(ceed, ierr);
  ierr = cudaMemcpy(data->d_qweight1d, qweight1d, qBytes,
                    cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);

  // interp1d and grad1d are both Q1d x P1d matrices
  const CeedInt iBytes = qBytes * P1d;
  ierr = cudaMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Cu(ceed, ierr);
  ierr = cudaMemcpy(data->d_interp1d, interp1d, iBytes,
                    cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);

  ierr = cudaMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Cu(ceed, ierr);
  ierr = cudaMemcpy(data->d_grad1d, grad1d, iBytes,
                    cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);

  // Compute collocated gradient and copy to GPU
  // Only built for 3D bases with Q1d >= P1d; the Q1d x Q1d collocated
  // gradient presumably feeds a collocated 3D grad kernel path —
  // NOTE(review): confirm the consumer in the kernel source
  data->d_collograd1d = NULL;
  if (dim == 3 && Q1d >= P1d) {
    CeedScalar *collograd1d;
    ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChkBackend(ierr);
    ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChkBackend(ierr);
    ierr = cudaMalloc((void **)&data->d_collograd1d, qBytes * Q1d);
    CeedChk_Cu(ceed, ierr);
    ierr = cudaMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d,
                      cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
    ierr = CeedFree(&collograd1d); CeedChkBackend(ierr);
  }

  // Compile basis kernels
  // The 8 macro definitions specialize the kernelsShared string at JIT time;
  // T1D (= max(Q1d, P1d)) sizes the per-thread register arrays
  CeedInt ncomp;
  ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChkBackend(ierr);
  ierr = CeedCompileCuda(ceed, kernelsShared, &data->module, 8,
                         "Q1D", Q1d,
                         "P1D", P1d,
                         "T1D", CeedIntMax(Q1d, P1d),
                         "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
                             Q1d : P1d, dim),
                         "BASIS_DIM", dim,
                         "BASIS_NCOMP", ncomp,
                         "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
                         "BASIS_NQPT", CeedIntPow(Q1d, dim)
                        ); CeedChkBackend(ierr);
  // Look up the three entry-point kernels in the compiled module
  ierr = CeedGetKernelCuda(ceed, data->module, "interp", &data->interp);
  CeedChkBackend(ierr);
  ierr = CeedGetKernelCuda(ceed, data->module, "grad", &data->grad);
  CeedChkBackend(ierr);
  ierr = CeedGetKernelCuda(ceed, data->module, "weight", &data->weight);
  CeedChkBackend(ierr);

  ierr = CeedBasisSetData(basis, data); CeedChkBackend(ierr);

  // Register backend functions
  ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
                                CeedBasisApplyTensor_Cuda_shared);
  CeedChkBackend(ierr);
  ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
                                CeedBasisDestroy_Cuda_shared); CeedChkBackend(ierr);
  return CEED_ERROR_SUCCESS;
}
1024 //------------------------------------------------------------------------------
1025