xref: /libCEED/backends/cuda-shared/ceed-cuda-shared-basis.c (revision 80ac2e4337097b2baa4325022af48978c55bdefe)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <ceed-backend.h>
18 #include <ceed.h>
19 #include "ceed-cuda-shared.h"
20 #include "../cuda/ceed-cuda.h"
21 
22 //------------------------------------------------------------------------------
23 // Shared mem kernels
24 //------------------------------------------------------------------------------
25 // *INDENT-OFF*
26 static const char *kernelsShared = QUOTE(
27 
28 //------------------------------------------------------------------------------
29 // Sum input into output
30 //------------------------------------------------------------------------------
// Accumulate: r_V[q] += r_U[q] for each of the Q1D per-thread values
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  for (int q = 0; q < Q1D; q++) {
    r_V[q] = r_V[q] + r_U[q];
  }
}
35 
36 //------------------------------------------------------------------------------
37 // 1D
38 //------------------------------------------------------------------------------
39 
40 //------------------------------------------------------------------------------
41 // Read DoFs
42 //------------------------------------------------------------------------------
// Load the P1D nodal values of (elem, comp) into this thread's row of the
// shared slice, zero-padding entries [P1D, Q1D) so later contractions can
// iterate over a full Q1D-length row.
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int tidz,const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *slice) {
  const int offset = elem*P1D + comp*P1D*nelem;
  for (int n = 0; n < Q1D; n++)
    slice[n + tidz*Q1D] = (n < P1D) ? d_U[n + offset] : 0.0;
}
52 
53 //------------------------------------------------------------------------------
54 // Write DoFs
55 //------------------------------------------------------------------------------
// Store one nodal value; only the first P1D threads in x hold valid output
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx < P1D) {
    const int offset = elem*P1D + comp*P1D*nelem;
    d_V[tidx + offset] = r_V;
  }
}
63 
64 //------------------------------------------------------------------------------
65 // Read quadrature point data
66 //------------------------------------------------------------------------------
// Load the Q1D quadrature values of (elem, comp, dim) into this thread's
// row of the shared slice
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int tidz, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *slice) {
  const int offset = elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  for (int q = 0; q < Q1D; q++)
    slice[q + tidz*Q1D] = d_U[q + offset];
}
75 
76 //------------------------------------------------------------------------------
77 // Write quadrature point data
78 //------------------------------------------------------------------------------
// Store one quadrature-point value for (elem, comp, dim); one thread per point
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  const int ind = tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  d_V[ind] = r_V;
}
85 
86 //------------------------------------------------------------------------------
87 // 1D tensor contraction
88 //------------------------------------------------------------------------------
// 1D forward contraction: V = sum_n B[n, tidx] * slice[n] over the P1D
// nodes in this thread's shared-memory row (U is unused; input comes from
// the slice populated by readDofs1d)
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  CeedScalar sum = 0.0;
  for (int n = 0; n < P1D; ++n)
    sum += B[n + tidx*P1D] * slice[n + tidz*Q1D]; // Contract x direction
  V = sum;
}
97 
98 //------------------------------------------------------------------------------
99 // 1D transpose tensor contraction
100 //------------------------------------------------------------------------------
// 1D transpose contraction: V = sum_q B^T[tidx, q] * slice[q] over the Q1D
// quadrature points in this thread's shared-memory row (U is unused; input
// comes from the slice populated by readQuads1d).
// Guard on tidx < P1D: B holds P1D*Q1D entries, so B[tidx + i*P1D] is only
// in bounds for tidx < P1D. Threads with tidx >= P1D previously read past
// the end of B and computed garbage that writeDofs1d then discarded; the
// guard (matching ContractTransposeX2d/Y2d) removes the OOB access.
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  V = 0.0;
  if (tidx < P1D)
    for (int i = 0; i < Q1D; ++i)
      V += B[tidx + i*P1D] * slice[i + tidz*Q1D]; // Contract x direction
}
108 
109 //------------------------------------------------------------------------------
110 // 1D interpolate to quadrature points
111 //------------------------------------------------------------------------------
// 1D interpolation: nodes -> quadrature points with B (or the reverse with
// B^T when transpose is set). Elements are distributed over blockDim.z and
// strided across the grid; `slice` is block-shared scratch of at least
// blockDim.z*Q1D scalars.
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  CeedScalar r_in;   // dummy register input (contractions read the slice)
  CeedScalar r_out;  // contraction result

  for (CeedInt elem = blockIdx.x*blockDim.z + tidz; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for (int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (transpose) {
        // Quadrature -> nodes: apply B^T
        readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_in, c_B, r_out);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_out, d_V);
      } else {
        // Nodes -> quadrature: apply B
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_in, c_B, r_out);
        writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_out, d_V);
      }
    }
  }
}
140 
141 //------------------------------------------------------------------------------
142 // 1D derivatives at quadrature points
143 //------------------------------------------------------------------------------
// 1D derivatives: nodes -> quadrature points with G (or the reverse with
// G^T when transpose is set). Same element distribution and shared-slice
// usage as interp1d; in 1D there is a single derivative direction (dim 0).
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  CeedScalar r_in;   // dummy register input (contractions read the slice)
  CeedScalar r_out;  // contraction result

  for (CeedInt elem = blockIdx.x*blockDim.z + tidz; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for (int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (transpose) {
        // Quadrature -> nodes: apply G^T
        readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_in, c_G, r_out);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_out, d_V);
      } else {
        // Nodes -> quadrature: apply G
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_in, c_G, r_out);
        writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_out, d_V);
      }
    }
  }
}
174 
175 //------------------------------------------------------------------------------
176 // 1D Quadrature weights
177 //------------------------------------------------------------------------------
// 1D quadrature weights: thread x owns one quadrature point; elements are
// distributed over blockDim.y and strided across the grid
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int tid = threadIdx.x;
  const CeedScalar wq = qweight1d[tid];
  for (CeedInt elem = blockIdx.x*blockDim.y + threadIdx.y; elem < nelem;
       elem += gridDim.x*blockDim.y) {
    w[elem*Q1D + tid] = wq;
  }
}
188 
189 //------------------------------------------------------------------------------
190 // 2D
191 //------------------------------------------------------------------------------
192 
193 //------------------------------------------------------------------------------
194 // Read DoFs
195 //------------------------------------------------------------------------------
// Load one nodal value of (elem, comp) into register U; threads outside the
// P1D x P1D node grid receive zero so padded contractions stay correct
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar &U) {
  const bool active = tidx < P1D && tidy < P1D;
  U = active ? d_U[tidx + tidy*P1D + (elem + comp*nelem)*P1D*P1D] : 0.0;
}
203 
204 //------------------------------------------------------------------------------
205 // Write DoFs
206 //------------------------------------------------------------------------------
// Store one nodal value; only threads inside the P1D x P1D node grid write
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx < P1D && tidy < P1D) {
    d_V[tidx + tidy*P1D + (elem + comp*nelem)*P1D*P1D] = r_V;
  }
}
214 
215 //------------------------------------------------------------------------------
216 // Read quadrature point data
217 //------------------------------------------------------------------------------
// Load one quadrature-point value of (elem, comp, dim) into register U
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar &U ) {
  const int ind = tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                  dim*BASIS_NCOMP*nelem*Q1D*Q1D;
  U = d_U[ind];
}
225 
226 //------------------------------------------------------------------------------
227 // Write quadrature point data
228 //------------------------------------------------------------------------------
// Store one quadrature-point value for (elem, comp, dim)
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  const int ind = tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                  dim*BASIS_NCOMP*nelem*Q1D*Q1D;
  d_V[ind] = r_V;
}
236 
237 //------------------------------------------------------------------------------
238 // 2D tensor contraction x
239 //------------------------------------------------------------------------------
// 2D forward contraction in x: stage U in shared memory, then each thread
// computes sum_n B[n, tidx] * slice[n, tidy]. Barriers protect the shared
// tile before the read and before the next reuse; all block threads must
// call this together.
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx + tidy*Q1D + tidz*Q1D*Q1D] = U;
  __syncthreads();
  CeedScalar sum = 0.0;
  for (int n = 0; n < P1D; ++n)
    sum += B[n + tidx*P1D] * slice[n + tidy*Q1D + tidz*Q1D*Q1D]; // Contract x direction
  V = sum;
  __syncthreads();
}
251 
252 //------------------------------------------------------------------------------
253 // 2D tensor contraction y
254 //------------------------------------------------------------------------------
// 2D forward contraction in y: stage U in shared memory, then each thread
// computes sum_n B[n, tidy] * slice[tidx, n]. Must be called by all block
// threads (contains __syncthreads).
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx + tidy*Q1D + tidz*Q1D*Q1D] = U;
  __syncthreads();
  CeedScalar sum = 0.0;
  for (int n = 0; n < P1D; ++n)
    sum += B[n + tidy*P1D] * slice[tidx + n*Q1D + tidz*Q1D*Q1D]; // Contract y direction
  V = sum;
  __syncthreads();
}
266 
267 //------------------------------------------------------------------------------
268 // 2D transpose tensor contraction y
269 //------------------------------------------------------------------------------
// 2D transpose contraction in y: stage U in shared memory, then threads with
// tidy < P1D compute sum_q B^T[tidy, q] * slice[tidx, q]; others get zero.
// Must be called by all block threads (contains __syncthreads).
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx + tidy*Q1D + tidz*Q1D*Q1D] = U;
  __syncthreads();
  CeedScalar sum = 0.0;
  if (tidy < P1D) {
    for (int q = 0; q < Q1D; ++q)
      sum += B[tidy + q*P1D] * slice[tidx + q*Q1D + tidz*Q1D*Q1D]; // Contract y direction
  }
  V = sum;
  __syncthreads();
}
281 
282 //------------------------------------------------------------------------------
283 // 2D transpose tensor contraction x
284 //------------------------------------------------------------------------------
// 2D transpose contraction in x: stage U in shared memory, then threads with
// tidx < P1D compute sum_q B^T[tidx, q] * slice[q, tidy]; others get zero.
// Must be called by all block threads (contains __syncthreads).
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx + tidy*Q1D + tidz*Q1D*Q1D] = U;
  __syncthreads();
  CeedScalar sum = 0.0;
  if (tidx < P1D) {
    for (int q = 0; q < Q1D; ++q)
      sum += B[tidx + q*P1D] * slice[q + tidy*Q1D + tidz*Q1D*Q1D]; // Contract x direction
  }
  V = sum;
  __syncthreads();
}
296 
297 //------------------------------------------------------------------------------
298 // 2D interpolate to quadrature points
299 //------------------------------------------------------------------------------
// 2D interpolation: contract in x then y with B (reverse order with B^T for
// transpose). blockDim.z threads are partitioned into (element, component)
// pairs: tidz/BASIS_NCOMP selects the element slot, tidz%BASIS_NCOMP the
// component. `slice` is block-shared scratch of blockDim.z*Q1D*Q1D scalars.
// Fixes: `comp` was declared twice (the loop-body copy at the old line 317
// shadowed the identical one above the loop) — the redundant inner
// declaration is removed. The per-iteration `r_V = 0.0; r_t = 0.0` stores
// were dead (both registers are unconditionally overwritten by the
// read/contract calls) and are removed as well.
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if (!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
    } else {
      readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
333 
334 //------------------------------------------------------------------------------
335 // 2D derivatives at quadrature points
336 //------------------------------------------------------------------------------
// 2D derivatives: for each direction, apply G in the differentiated
// dimension and B in the other. Forward writes one quadrature field per
// dim; transpose reads both fields and accumulates their contributions
// into the nodal output. Same (element, component) thread partition as
// interp2d.
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V, CeedScalar *slice) {
  CeedScalar r_in;
  CeedScalar r_tmp;
  CeedScalar r_out;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if (!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_in);
      // d/dx: G in x, B in y
      ContractX2d(slice, tidx, tidy, tidz, r_in, c_G, r_tmp);
      ContractY2d(slice, tidx, tidy, tidz, r_tmp, c_B, r_out);
      writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_out, d_V);
      // d/dy: B in x, G in y
      ContractX2d(slice, tidx, tidy, tidz, r_in, c_B, r_tmp);
      ContractY2d(slice, tidx, tidy, tidz, r_tmp, c_G, r_out);
      writeQuads2d(elem, tidx, tidy, comp, 1, nelem, r_out, d_V);
    } else {
      // Transpose of d/dx
      readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_in);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_in, c_B, r_tmp);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_tmp, c_G, r_out);
      // Transpose of d/dy, accumulated into the same nodal values
      readQuads2d(elem, tidx, tidy, comp, 1, nelem, d_U, r_in);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_in, c_G, r_tmp);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_tmp, c_B, r_in);
      r_out += r_in;
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_out, d_V);
    }
  }
}
379 
380 //------------------------------------------------------------------------------
381 // 2D quadrature weights
382 //------------------------------------------------------------------------------
// 2D quadrature weights: thread (x, y) owns point (i, j); the weight is the
// tensor product of the 1D weights. Elements strided over blockDim.z and grid.
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int i = threadIdx.x;
  const int j = threadIdx.y;
  const CeedScalar wq = qweight1d[i]*qweight1d[j];
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    w[elem*Q1D*Q1D + i + j*Q1D] = wq;
  }
}
394 
395 //------------------------------------------------------------------------------
396 // 3D
397 //------------------------------------------------------------------------------
398 
399 //------------------------------------------------------------------------------
400 // Read DoFs
401 //------------------------------------------------------------------------------
// Load this thread's z-column of P1D nodal values for (elem, comp) into
// registers, zero-padding entries [P1D, Q1D) and threads outside the
// P1D x P1D node grid
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *r_U) {
  const bool active = tidx < P1D && tidy < P1D;
  const int base = tidx + tidy*P1D + (elem + comp*nelem)*P1D*P1D*P1D;
  for (int k = 0; k < Q1D; k++)
    r_U[k] = (active && k < P1D) ? d_U[base + k*P1D*P1D] : 0.0;
}
413 
414 //------------------------------------------------------------------------------
415 // Write DoFs
416 //------------------------------------------------------------------------------
// Store this thread's z-column of P1D nodal values; only threads inside the
// P1D x P1D node grid write
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V,
                                   CeedScalar *d_V) {
  if (tidx < P1D && tidy < P1D) {
    const int base = tidx + tidy*P1D + (elem + comp*nelem)*P1D*P1D*P1D;
    for (int k = 0; k < P1D; k++)
      d_V[base + k*P1D*P1D] = r_V[k];
  }
}
427 
428 //------------------------------------------------------------------------------
429 // Read quadrature point data
430 //------------------------------------------------------------------------------
// Load this thread's z-column of Q1D quadrature values for (elem, comp, dim)
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *r_U) {
  const int base = tidx + tidy*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
                   dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
  for (int k = 0; k < Q1D; k++)
    r_U[k] = d_U[base + k*Q1D*Q1D];
}
439 
440 //------------------------------------------------------------------------------
441 // Write quadrature point data
442 //------------------------------------------------------------------------------
// Store this thread's z-column of Q1D quadrature values for (elem, comp, dim)
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar *r_V, CeedScalar *d_V) {
  const int base = tidx + tidy*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
                   dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
  for (int k = 0; k < Q1D; k++)
    d_V[base + k*Q1D*Q1D] = r_V[k];
}
451 
452 //------------------------------------------------------------------------------
453 // 3D tensor contract x
454 //------------------------------------------------------------------------------
// 3D forward contraction in x, one z-layer at a time: stage U[m] in the
// shared tile, then V[m] = sum_n B[n, tidx] * slice[n, tidy]. U and V must
// not alias. Must be called by all block threads (contains __syncthreads).
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B,
                                   CeedScalar *V) {
  for (int m = 0; m < P1D; ++m) {
    slice[tidx + tidy*Q1D + tidz*Q1D*Q1D] = U[m];
    __syncthreads();
    CeedScalar sum = 0.0;
    for (int n = 0; n < P1D; ++n)
      sum += B[n + tidx*P1D] * slice[n + tidy*Q1D + tidz*Q1D*Q1D]; // Contract x direction
    V[m] = sum;
    __syncthreads();
  }
}
468 
469 //------------------------------------------------------------------------------
470 // 3D tensor contract y
471 //------------------------------------------------------------------------------
// 3D forward contraction in y, one z-layer at a time: stage U[m] in the
// shared tile, then V[m] = sum_n B[n, tidy] * slice[tidx, n]. U and V must
// not alias. Must be called by all block threads (contains __syncthreads).
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B,
                                   CeedScalar *V) {
  for (int m = 0; m < P1D; ++m) {
    slice[tidx + tidy*Q1D + tidz*Q1D*Q1D] = U[m];
    __syncthreads();
    CeedScalar sum = 0.0;
    for (int n = 0; n < P1D; ++n)
      sum += B[n + tidy*P1D] * slice[tidx + n*Q1D + tidz*Q1D*Q1D]; // Contract y direction
    V[m] = sum;
    __syncthreads();
  }
}
485 
486 //------------------------------------------------------------------------------
487 // 3D tensor contract z
488 //------------------------------------------------------------------------------
// 3D forward contraction in z: purely register-based since each thread holds
// its whole z-column; V[q] = sum_n B[n, q] * U[n]. U and V must not alias.
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B,
                                   CeedScalar *V) {
  for (int q = 0; q < Q1D; ++q) {
    CeedScalar sum = 0.0;
    for (int n = 0; n < P1D; ++n)
      sum += B[n + q*P1D] * U[n]; // Contract z direction
    V[q] = sum;
  }
}
499 
500 //------------------------------------------------------------------------------
501 // 3D transpose tensor contract z
502 //------------------------------------------------------------------------------
// 3D transpose contraction in z: register-based; V[q] = sum_n B^T[q, n]*U[n]
// for q < P1D, with the padding entries q in [P1D, Q1D) zeroed. U and V
// must not alias.
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int q = 0; q < Q1D; ++q) {
    CeedScalar sum = 0.0;
    if (q < P1D)
      for (int n = 0; n < Q1D; ++n)
        sum += B[q + n*P1D] * U[n]; // Contract z direction
    V[q] = sum;
  }
}
513 
514 //------------------------------------------------------------------------------
515 // 3D transpose tensor contract y
516 //------------------------------------------------------------------------------
// 3D transpose contraction in y, one z-layer at a time: stage U[m] in the
// shared tile, then threads with tidy < P1D compute
// V[m] = sum_q B^T[tidy, q] * slice[tidx, q]; others get zero. U and V must
// not alias. Must be called by all block threads (contains __syncthreads).
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int m = 0; m < P1D; ++m) {
    slice[tidx + tidy*Q1D + tidz*Q1D*Q1D] = U[m];
    __syncthreads();
    CeedScalar sum = 0.0;
    if (tidy < P1D)
      for (int q = 0; q < Q1D; ++q)
        sum += B[tidy + q*P1D] * slice[tidx + q*Q1D + tidz*Q1D*Q1D]; // Contract y direction
    V[m] = sum;
    __syncthreads();
  }
}
530 
531 //------------------------------------------------------------------------------
532 // 3D transpose tensor contract x
533 //------------------------------------------------------------------------------
// 3D transpose contraction in x, one z-layer at a time: stage U[m] in the
// shared tile, then threads with tidx < P1D compute
// V[m] = sum_q B^T[tidx, q] * slice[q, tidy]; others get zero. U and V must
// not alias. Must be called by all block threads (contains __syncthreads).
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int m = 0; m < P1D; ++m) {
    slice[tidx + tidy*Q1D + tidz*Q1D*Q1D] = U[m];
    __syncthreads();
    CeedScalar sum = 0.0;
    if (tidx < P1D)
      for (int q = 0; q < Q1D; ++q)
        sum += B[tidx + q*P1D] * slice[q + tidy*Q1D + tidz*Q1D*Q1D]; // Contract x direction
    V[m] = sum;
    __syncthreads();
  }
}
547 
548 //------------------------------------------------------------------------------
549 // 3D interpolate to quadrature points
550 //------------------------------------------------------------------------------
// 3D interpolation: contract in x, y, then z with B (reverse order with B^T
// for transpose). Each thread keeps a full z-column of Q1D values in
// registers; (element, component) pairs are distributed over blockDim.z as
// in interp2d. `slice` is block-shared scratch of blockDim.z*Q1D*Q1D scalars.
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_a[Q1D];
  CeedScalar r_b[Q1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // Clear both register columns before this element's contractions
    for (int k = 0; k < Q1D; ++k) {
      r_a[k] = 0.0;
      r_b[k] = 0.0;
    }
    if (transpose) {
      readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_a);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_a, c_B, r_b);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_b, c_B, r_a);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_a, c_B, r_b);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_b, d_V);
    } else {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_a);
      ContractX3d(slice, tidx, tidy, tidz, r_a, c_B, r_b);
      ContractY3d(slice, tidx, tidy, tidz, r_b, c_B, r_a);
      ContractZ3d(slice, tidx, tidy, tidz, r_a, c_B, r_b);
      writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_b, d_V);
    }
  }
}
587 
588 //------------------------------------------------------------------------------
589 // 3D derivatives at quadrature points
590 //------------------------------------------------------------------------------
// 3D derivatives: for each direction d, apply G in dimension d and B in the
// other two. Forward writes one quadrature field per dim (0, 1, 2); the
// transpose reads all three fields and accumulates their contributions into
// a single nodal output via add(). Each thread holds full z-columns in
// registers; the contraction order and register reuse below (r_U/r_V/r_t
// rotate roles) is deliberate and must not be reordered.
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // Use P1D for one of these
  // (note: r_U only ever needs P1D entries in the forward pass)
  CeedScalar r_U[Q1D];
  CeedScalar r_V[Q1D];
  CeedScalar r_t[Q1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  // blockDim.z threads are partitioned into (element, component) pairs
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if (!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // d/dx: G in x, B in y and z
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // d/dy: G in y, B in x and z
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_G, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 1;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // d/dz: G in z, B in x and y
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 2;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      // Transpose of d/dx: result lands in r_V
      dim = 0;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      // Transpose of d/dy, accumulated into r_V
      dim = 1;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_G, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      // Transpose of d/dz, accumulated into r_V
      dim = 2;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
650 
651 //------------------------------------------------------------------------------
652 // 3D quadrature weights
653 //------------------------------------------------------------------------------
// 3D quadrature weights: thread (x, y, z) owns point (i, j, k); the weight
// is the tensor product of the 1D weights. One element per block, strided
// over the grid.
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int i = threadIdx.x;
  const int j = threadIdx.y;
  const int k = threadIdx.z;
  const CeedScalar wq = qweight1d[i]*qweight1d[j]*qweight1d[k];
  for (int e = blockIdx.x; e < nelem; e += gridDim.x) {
    w[e*Q1D*Q1D*Q1D + i + j*Q1D + k*Q1D*Q1D] = wq;
  }
}
665 
666 
667 //------------------------------------------------------------------------------
668 // Basis kernels
669 //------------------------------------------------------------------------------
670 
671 //------------------------------------------------------------------------------
672 // Interp kernel by dim
673 //------------------------------------------------------------------------------
extern "C" __global__ void interp(const CeedInt nelem, const int transpose,
                                  const CeedScalar *c_B,
                                  const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  // Dynamic shared-memory scratch slab used by every dimension variant
  extern __shared__ double slice[];
  // BASIS_DIM is a compile-time constant, so only one case survives compilation
  switch (BASIS_DIM) {
  case 1:
    interp1d(nelem, transpose, c_B, d_U, d_V, slice);
    break;
  case 2:
    interp2d(nelem, transpose, c_B, d_U, d_V, slice);
    break;
  case 3:
    interp3d(nelem, transpose, c_B, d_U, d_V, slice);
    break;
  }
}
687 
688 //------------------------------------------------------------------------------
689 // Grad kernel by dim
690 //------------------------------------------------------------------------------
extern "C" __global__ void grad(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *c_G,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V) {
  // Dynamic shared-memory scratch slab used by every dimension variant
  extern __shared__ double slice[];
  // BASIS_DIM is a compile-time constant, so only one case survives compilation
  switch (BASIS_DIM) {
  case 1:
    grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
    break;
  case 2:
    grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
    break;
  case 3:
    grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
    break;
  }
}
704 
705 //------------------------------------------------------------------------------
706 // Weight kernels by dim
707 //------------------------------------------------------------------------------
extern "C" __global__ void weight(const CeedInt nelem,
                                  const CeedScalar *__restrict__ qweight1d,
                                  CeedScalar *__restrict__ v) {
  // BASIS_DIM is a compile-time constant, so only one case survives compilation
  switch (BASIS_DIM) {
  case 1:
    weight1d(nelem, qweight1d, v);
    break;
  case 2:
    weight2d(nelem, qweight1d, v);
    break;
  case 3:
    weight3d(nelem, qweight1d, v);
    break;
  }
}
719 
720 );
721 // *INDENT-ON*
722 
723 //------------------------------------------------------------------------------
// Device initialization
725 //------------------------------------------------------------------------------
// Load the 1D interp (and, for the second variant, grad) matrices from device
// buffers d_B/d_G into the cached pointers c_B/c_G used as kernel arguments.
// Both return a CEED error code (callers below check with CeedChk).
// Implemented in the cuda-shared backend's device-initialization source.
int CeedCudaInitInterp(CeedScalar *d_B, CeedInt P1d, CeedInt Q1d,
                       CeedScalar **c_B);
int CeedCudaInitInterpGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P1d,
                           CeedInt Q1d, CeedScalar **c_B_ptr,
                           CeedScalar **c_G_ptr);
731 
732 //------------------------------------------------------------------------------
733 // Apply basis
734 //------------------------------------------------------------------------------
735 int CeedBasisApplyTensor_Cuda_shared(CeedBasis basis, const CeedInt nelem,
736                                      CeedTransposeMode tmode,
737                                      CeedEvalMode emode, CeedVector u,
738                                      CeedVector v) {
739   int ierr;
740   Ceed ceed;
741   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
742   Ceed_Cuda_shared *ceed_Cuda;
743   CeedGetData(ceed, (void *) &ceed_Cuda); CeedChk(ierr);
744   CeedBasis_Cuda_shared *data;
745   CeedBasisGetData(basis, (void *)&data); CeedChk(ierr);
746   const CeedInt transpose = tmode == CEED_TRANSPOSE;
747   CeedInt dim, ncomp;
748   ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
749   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
750 
751   // Read vectors
752   const CeedScalar *d_u;
753   CeedScalar *d_v;
754   if (emode != CEED_EVAL_WEIGHT) {
755     ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChk(ierr);
756   }
757   ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChk(ierr);
758 
759   // Clear v for transpose mode
760   if (tmode == CEED_TRANSPOSE) {
761     CeedInt length;
762     ierr = CeedVectorGetLength(v, &length); CeedChk(ierr);
763     ierr = cudaMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChk(ierr);
764   }
765 
766   // Apply basis operation
767   switch (emode) {
768   case CEED_EVAL_INTERP: {
769     CeedInt P1d, Q1d;
770     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
771     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
772     ierr = CeedCudaInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B);
773     CeedChk(ierr);
774     void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
775                           &d_u, &d_v
776                          };
777     if (dim == 1) {
778       CeedInt elemsPerBlock = 32;
779       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
780                                              ? 1 : 0 );
781       CeedInt sharedMem = elemsPerBlock*Q1d*sizeof(CeedScalar);
782       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, 1,
783                                         elemsPerBlock, sharedMem,
784                                         interpargs); CeedChk(ierr);
785     } else if (dim == 2) {
786       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
787       CeedInt elemsPerBlock = Q1d < 7 ? optElems[Q1d]/ncomp : 1;
788       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
789                                              ? 1 : 0 );
790       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
791       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, Q1d,
792                                         ncomp*elemsPerBlock, sharedMem,
793                                         interpargs); CeedChk(ierr);
794     } else if (dim == 3) {
795       CeedInt elemsPerBlock = 1;
796       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
797                                              ? 1 : 0 );
798       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
799       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, Q1d,
800                                         ncomp*elemsPerBlock, sharedMem,
801                                         interpargs); CeedChk(ierr);
802     }
803   } break;
804   case CEED_EVAL_GRAD: {
805     CeedInt P1d, Q1d;
806     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
807     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
808     ierr = CeedCudaInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d,
809                                   Q1d, &data->c_B, &data->c_G);
810     CeedChk(ierr);
811     void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
812                         &data->c_G, &d_u, &d_v
813                        };
814     if (dim == 1) {
815       CeedInt elemsPerBlock = 32;
816       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
817                                              ? 1 : 0 );
818       CeedInt sharedMem = elemsPerBlock*Q1d*sizeof(CeedScalar);
819       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, 1,
820                                         elemsPerBlock, sharedMem, gradargs);
821       CeedChk(ierr);
822     } else if (dim == 2) {
823       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
824       CeedInt elemsPerBlock = Q1d < 7 ? optElems[Q1d]/ncomp : 1;
825       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
826                                              ? 1 : 0 );
827       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
828       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, Q1d,
829                                         ncomp*elemsPerBlock, sharedMem,
830                                         gradargs); CeedChk(ierr);
831     } else if (dim == 3) {
832       CeedInt elemsPerBlock = 1;
833       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
834                                              ? 1 : 0 );
835       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
836       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, Q1d,
837                                         ncomp*elemsPerBlock, sharedMem,
838                                         gradargs); CeedChk(ierr);
839     }
840   } break;
841   case CEED_EVAL_WEIGHT: {
842     CeedInt Q1d;
843     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
844     void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
845     if (dim == 1) {
846       const CeedInt elemsPerBlock = 32/Q1d;
847       const CeedInt gridsize = nelem/elemsPerBlock + ( (
848                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
849       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d,
850                                   elemsPerBlock, 1, weightargs);
851       CeedChk(ierr);
852     } else if (dim == 2) {
853       const CeedInt optElems = 32/(Q1d*Q1d);
854       const CeedInt elemsPerBlock = optElems>0?optElems:1;
855       const CeedInt gridsize = nelem/elemsPerBlock + ( (
856                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
857       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d,
858                                   elemsPerBlock, weightargs);
859       CeedChk(ierr);
860     } else if (dim == 3) {
861       const CeedInt gridsize = nelem;
862       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d, Q1d,
863                                   weightargs);
864       CeedChk(ierr);
865     }
866   } break;
867   // LCOV_EXCL_START
868   // Evaluate the divergence to/from the quadrature points
869   case CEED_EVAL_DIV:
870     return CeedError(ceed, 1, "CEED_EVAL_DIV not supported");
871   // Evaluate the curl to/from the quadrature points
872   case CEED_EVAL_CURL:
873     return CeedError(ceed, 1, "CEED_EVAL_CURL not supported");
874   // Take no action, BasisApply should not have been called
875   case CEED_EVAL_NONE:
876     return CeedError(ceed, 1,
877                      "CEED_EVAL_NONE does not make sense in this context");
878     // LCOV_EXCL_STOP
879   }
880 
881   // Restore vectors
882   if (emode != CEED_EVAL_WEIGHT) {
883     ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChk(ierr);
884   }
885   ierr = CeedVectorRestoreArray(v, &d_v); CeedChk(ierr);
886   return 0;
887 }
888 
889 //------------------------------------------------------------------------------
890 // Destroy basis
891 //------------------------------------------------------------------------------
892 static int CeedBasisDestroy_Cuda_shared(CeedBasis basis) {
893   int ierr;
894   Ceed ceed;
895   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
896 
897   CeedBasis_Cuda_shared *data;
898   ierr = CeedBasisGetData(basis, (void *) &data); CeedChk(ierr);
899 
900   CeedChk_Cu(ceed, cuModuleUnload(data->module));
901 
902   ierr = cudaFree(data->d_qweight1d); CeedChk_Cu(ceed, ierr);
903   ierr = cudaFree(data->d_interp1d); CeedChk_Cu(ceed, ierr);
904   ierr = cudaFree(data->d_grad1d); CeedChk_Cu(ceed, ierr);
905 
906   ierr = CeedFree(&data); CeedChk(ierr);
907 
908   return 0;
909 }
910 
911 //------------------------------------------------------------------------------
912 // Create tensor basis
913 //------------------------------------------------------------------------------
914 int CeedBasisCreateTensorH1_Cuda_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
915                                         const CeedScalar *interp1d,
916                                         const CeedScalar *grad1d,
917                                         const CeedScalar *qref1d,
918                                         const CeedScalar *qweight1d,
919                                         CeedBasis basis) {
920   int ierr;
921   Ceed ceed;
922   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
923   if (Q1d<P1d) {
924     return CeedError(ceed, 1, "Backend does not implement underintegrated basis.");
925   }
926   CeedBasis_Cuda_shared *data;
927   ierr = CeedCalloc(1, &data); CeedChk(ierr);
928 
929   // Copy basis data to GPU
930   const CeedInt qBytes = Q1d * sizeof(CeedScalar);
931   ierr = cudaMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Cu(ceed, ierr);
932   ierr = cudaMemcpy(data->d_qweight1d, qweight1d, qBytes,
933                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
934 
935   const CeedInt iBytes = qBytes * P1d;
936   ierr = cudaMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Cu(ceed, ierr);
937   ierr = cudaMemcpy(data->d_interp1d, interp1d, iBytes,
938                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
939 
940   ierr = cudaMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Cu(ceed, ierr);
941   ierr = cudaMemcpy(data->d_grad1d, grad1d, iBytes,
942                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
943 
944   // Compute collocated gradient and copy to GPU
945   data->d_collograd1d = NULL;
946   if (dim == 3 && Q1d >= P1d) {
947     CeedScalar *collograd1d;
948     ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChk(ierr);
949     ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChk(ierr);
950     ierr = cudaMalloc((void **)&data->d_collograd1d, qBytes * Q1d);
951     CeedChk_Cu(ceed, ierr);
952     ierr = cudaMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d,
953                       cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
954   }
955 
956   // Compile basis kernels
957   CeedInt ncomp;
958   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
959   ierr = CeedCompileCuda(ceed, kernelsShared, &data->module, 7,
960                          "Q1D", Q1d,
961                          "P1D", P1d,
962                          "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
963                              Q1d : P1d, dim),
964                          "BASIS_DIM", dim,
965                          "BASIS_NCOMP", ncomp,
966                          "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
967                          "BASIS_NQPT", CeedIntPow(Q1d, dim)
968                         ); CeedChk(ierr);
969   ierr = CeedGetKernelCuda(ceed, data->module, "interp", &data->interp);
970   CeedChk(ierr);
971   ierr = CeedGetKernelCuda(ceed, data->module, "grad", &data->grad);
972   CeedChk(ierr);
973   ierr = CeedGetKernelCuda(ceed, data->module, "weight", &data->weight);
974   CeedChk(ierr);
975 
976   ierr = CeedBasisSetData(basis, (void *)&data); CeedChk(ierr);
977 
978   // Register backend functions
979   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
980                                 CeedBasisApplyTensor_Cuda_shared);
981   CeedChk(ierr);
982   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
983                                 CeedBasisDestroy_Cuda_shared); CeedChk(ierr);
984   return 0;
985 }
986 //------------------------------------------------------------------------------
987