xref: /libCEED/backends/cuda-shared/ceed-cuda-shared-basis.c (revision bd33150addfd586a5f97816a8026f45b5a27ec2c)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <ceed-backend.h>
18 #include <ceed.h>
19 #include "ceed-cuda-shared.h"
20 #include "../cuda/ceed-cuda.h"
21 
22 //------------------------------------------------------------------------------
23 // Shared mem kernels
24 //------------------------------------------------------------------------------
25 // *INDENT-OFF*
26 static const char *kernelsShared = QUOTE(
27 
28 //------------------------------------------------------------------------------
29 // Sum input into output
30 //------------------------------------------------------------------------------
// Accumulate Q1D per-thread register values: r_V[q] += r_U[q] for each q.
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  for (int q = 0; q < Q1D; q++)
    r_V[q] += r_U[q];
}
35 
36 //------------------------------------------------------------------------------
37 // 1D
38 //------------------------------------------------------------------------------
39 
40 //------------------------------------------------------------------------------
41 // Read DoFs
42 //------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// Load the P1D element DoFs for one component into this thread's shared-memory
// row (tidz selects the row), zero-padding entries P1D..Q1D-1.
// Assumes Q1D >= P1D (slice rows hold Q1D values) -- consistent with the rest
// of this file.
//------------------------------------------------------------------------------
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int tidz,const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *slice) {
  // DoF layout: node fastest, then component, then element
  const int base = comp*P1D + elem*BASIS_NCOMP*P1D;
  for (int q = 0; q < Q1D; q++)
    slice[q + tidz*Q1D] = (q < P1D) ? d_U[base + q] : 0.0;
}
52 
53 //------------------------------------------------------------------------------
54 // Write DoFs
55 //------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// Store one DoF value; threads with tidx >= P1D carry only quadrature padding
// and write nothing.
//------------------------------------------------------------------------------
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D)
    return;
  d_V[tidx + comp*P1D + elem*BASIS_NCOMP*P1D] = r_V;
}
63 
64 //------------------------------------------------------------------------------
65 // Read quadrature point data
66 //------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// Load the Q1D quadrature values for one (component, derivative dim) pair into
// this thread's shared-memory row.
// Layout: point fastest, then element, component, derivative dimension.
//------------------------------------------------------------------------------
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int tidz, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *slice) {
  const int base = elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  for (int q = 0; q < Q1D; q++)
    slice[q + tidz*Q1D] = d_U[base + q];
}
75 
76 //------------------------------------------------------------------------------
77 // Write quadrature point data
78 //------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// Store one quadrature-point value; tidx is the point index within the element.
//------------------------------------------------------------------------------
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  const int ind = tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  d_V[ind] = r_V;
}
85 
86 //------------------------------------------------------------------------------
87 // 1D tensor contraction
88 //------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// 1D contraction: V = sum_p B(p, tidx) * slice(p) over the tidz row.
// The U parameter is unused -- input arrives via the shared slice.
//------------------------------------------------------------------------------
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  CeedScalar sum = 0.0;
  for (int p = 0; p < P1D; ++p)
    sum += B[p + tidx*P1D] * slice[p + tidz*Q1D]; // Contract x direction
  V = sum;
}
97 
98 //------------------------------------------------------------------------------
99 // 1D transpose tensor contraction
100 //------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// 1D transpose contraction: V = sum_q B(tidx, q) * slice(q) over the tidz row.
// The U parameter is unused -- input arrives via the shared slice.
//------------------------------------------------------------------------------
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  CeedScalar sum = 0.0;
  for (int q = 0; q < Q1D; ++q)
    sum += B[tidx + q*P1D] * slice[q + tidz*Q1D]; // Contract x direction
  V = sum;
}
108 
109 //------------------------------------------------------------------------------
110 // 1D interpolate to quadrature points
111 //------------------------------------------------------------------------------
// Apply the 1D interp matrix (transpose == 0: DoFs -> quad points, V = B U;
// otherwise quad points -> DoFs, V = B^T U) for all components of a batch of
// elements. Launch layout (inferred): blockDim.x spans the 1D points and
// blockDim.z batches elements -- TODO confirm against the host launch code.
// NOTE(review): within a comp iteration every thread of a tidz row writes the
// whole slice row with identical values, and there is no __syncthreads()
// between staging and reading; this appears to rely on warp-synchronous
// execution of a row -- verify on architectures with independent thread
// scheduling.
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;  // per-thread result value
  CeedScalar r_t;  // passed as the (unused) U argument of the contractions

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;


  // Stride over elements: each tidz slot owns one element per grid pass
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for (int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) {
        // Stage DoFs in shared memory, contract with B, store quad values
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      } else {
        // Stage quad values, contract with B^T, store DoFs
        readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
140 
141 //------------------------------------------------------------------------------
142 // 1D derivatives at quadrature points
143 //------------------------------------------------------------------------------
// Apply the 1D gradient matrix G (transpose == 0: DoFs -> derivative values at
// quad points, V = G U; otherwise V = G^T U) for all components of a batch of
// elements. Same launch layout and warp-synchronous slice reuse as interp1d
// above -- NOTE(review): no __syncthreads() between staging and contraction;
// verify on architectures with independent thread scheduling.
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  CeedScalar r_U;  // passed as the (unused) U argument of the contractions
  CeedScalar r_V;  // per-thread result value

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  int dim;         // derivative dimension index (always 0 in 1D)

  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) {
        // Stage DoFs, contract with G, store derivative values
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        dim = 0;
        writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      } else {
        // Stage derivative values, contract with G^T, store DoFs
        dim = 0;
        readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
174 
175 //------------------------------------------------------------------------------
176 // 1D Quadrature weights
177 //------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// Broadcast the 1D quadrature weight of this thread's point to every element.
// threadIdx.x selects the point; (blockIdx.x, threadIdx.y) stride the elements.
//------------------------------------------------------------------------------
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qpt = threadIdx.x;
  const CeedScalar wq = qweight1d[qpt];           // same for every element
  const CeedInt stride = gridDim.x*blockDim.y;
  for (CeedInt e = blockIdx.x*blockDim.y + threadIdx.y; e < nelem; e += stride)
    w[e*Q1D + qpt] = wq;
}
188 
189 //------------------------------------------------------------------------------
190 // 2D
191 //------------------------------------------------------------------------------
192 
193 //------------------------------------------------------------------------------
194 // Read DoFs
195 //------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// Load this thread's DoF for one component into a register; threads outside
// the P1D x P1D node grid receive zero padding.
//------------------------------------------------------------------------------
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar &U) {
  if (tidx < P1D && tidy < P1D)
    U = d_U[tidx + tidy*P1D + comp*P1D*P1D + elem*BASIS_NCOMP*P1D*P1D];
  else
    U = 0.0;
}
205 
206 //------------------------------------------------------------------------------
207 // Write DoFs
208 //------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// Store this thread's DoF value; padding threads (outside P1D x P1D) are
// skipped.
//------------------------------------------------------------------------------
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D)
    return;
  d_V[tidx + tidy*P1D + comp*P1D*P1D + elem*BASIS_NCOMP*P1D*P1D] = r_V;
}
216 
217 //------------------------------------------------------------------------------
218 // Read quadrature point data
219 //------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// Load this thread's quadrature-point value for one (component, derivative
// dim) pair. Layout: point fastest, then element, component, derivative dim.
//------------------------------------------------------------------------------
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar &U ) {
  const int ind = tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                  dim*BASIS_NCOMP*nelem*Q1D*Q1D;
  U = d_U[ind];
}
227 
228 //------------------------------------------------------------------------------
229 // Write quadrature point data
230 //------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// Store this thread's quadrature-point value (mirror of readQuads2d layout).
//------------------------------------------------------------------------------
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  const int ind = tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                  dim*BASIS_NCOMP*nelem*Q1D*Q1D;
  d_V[ind] = r_V;
}
238 
239 //------------------------------------------------------------------------------
240 // 2D tensor contraction x
241 //------------------------------------------------------------------------------
// 2D x-contraction: V(tidx, tidy) = sum_i B(i, tidx) * U(i, tidy), staged
// through shared memory. Both barriers are outside any divergent branch, so
// all block threads reach them; the trailing barrier protects slice from the
// caller's next staging write.
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx + tidy*Q1D + tidz*Q1D*Q1D] = U;  // stage this thread's value
  __syncthreads();                            // tile visible before reads
  V = 0.0;
  for (int i = 0; i < P1D; ++i)
    V += B[i + tidx*P1D] * slice[i + tidy*Q1D + tidz*Q1D*Q1D]; // Contract x direction
  __syncthreads();                            // keep slice live until all read
}
253 
254 //------------------------------------------------------------------------------
255 // 2D tensor contraction y
256 //------------------------------------------------------------------------------
// 2D y-contraction: V(tidx, tidy) = sum_i B(i, tidy) * U(tidx, i), staged
// through shared memory. Barrier discipline identical to ContractX2d.
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx + tidy*Q1D + tidz*Q1D*Q1D] = U;  // stage this thread's value
  __syncthreads();                            // tile visible before reads
  V = 0.0;
  for (int i = 0; i < P1D; ++i)
    V += B[i + tidy*P1D] * slice[tidx + i*Q1D + tidz*Q1D*Q1D]; // Contract y direction
  __syncthreads();                            // keep slice live until all read
}
268 
269 //------------------------------------------------------------------------------
270 // 2D transpose tensor contraction y
271 //------------------------------------------------------------------------------
// 2D transpose y-contraction: V(tidx, tidy) = sum_i B(tidy, i) * U(tidx, i).
// Only threads with tidy < P1D produce output (the rest keep V = 0); the
// barriers deliberately sit outside the divergent if so every thread reaches
// them.
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx + tidy*Q1D + tidz*Q1D*Q1D] = U;  // stage this thread's value
  __syncthreads();                            // tile visible before reads
  V = 0.0;
  if (tidy < P1D)
    for (int i = 0; i < Q1D; ++i)
      V += B[tidy + i*P1D] * slice[tidx + i*Q1D + tidz*Q1D*Q1D]; // Contract y direction
  __syncthreads();                            // keep slice live until all read
}
283 
284 //------------------------------------------------------------------------------
285 // 2D transpose tensor contraction x
286 //------------------------------------------------------------------------------
// 2D transpose x-contraction: V(tidx, tidy) = sum_i B(tidx, i) * U(i, tidy).
// Only threads with tidx < P1D produce output; barriers stay outside the
// divergent if so every thread reaches them.
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx + tidy*Q1D + tidz*Q1D*Q1D] = U;  // stage this thread's value
  __syncthreads();                            // tile visible before reads
  V = 0.0;
  if (tidx < P1D)
    for (int i = 0; i < Q1D; ++i)
      V += B[tidx + i*P1D] * slice[i + tidy*Q1D + tidz*Q1D*Q1D]; // Contract x direction
  __syncthreads();                            // keep slice live until all read
}
298 
299 //------------------------------------------------------------------------------
300 // 2D interpolate to quadrature points
301 //------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// Apply the 2D interp operator (transpose == 0: DoFs -> quad points,
// V = (B (x) B) U; otherwise V = (B (x) B)^T U) for a batch of elements.
// Launch layout (inferred): blockDim.x/y span the Q1D x Q1D points and
// blockDim.z carries BASIS_NCOMP components per element -- TODO confirm
// against the host launch code.
//------------------------------------------------------------------------------
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  // r_V and r_t ping-pong the value between the two tensor contractions; both
  // are fully assigned by the read/contract calls before any use, so no
  // zero-initialization is needed.
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;      // element slot within the block
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;           // component handled by this thread
  // (the original redeclared comp identically inside the loop; the shadowing
  // duplicate has been removed)

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if (!transpose) {
      // DoFs -> quad points: contract x then y with B
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
    } else {
      // Quad points -> DoFs: contract y then x with B^T
      readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
335 
336 //------------------------------------------------------------------------------
337 // 2D derivatives at quadrature points
338 //------------------------------------------------------------------------------
// Apply the 2D gradient operator for a batch of elements.
// transpose == 0: produce d/dx (dim 0, G (x) B) and d/dy (dim 1, B (x) G)
// at quadrature points from DoFs. transpose != 0: accumulate the transpose
// application of both derivative dimensions back into DoFs.
// Same inferred launch layout as interp2d (blockDim.z packs BASIS_NCOMP
// components per element).
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V, CeedScalar *slice) {
  CeedScalar r_U;  // holds the input value (reused across both dims)
  CeedScalar r_V;  // result value
  CeedScalar r_t;  // temporary between the two contractions

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;       // element slot within block
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;            // component for this thread
  int dim;                                      // derivative dimension (0 or 1)

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if (!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // dim 0: G in x, B in y
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // dim 1: B in x, G in y (r_U still holds the DoF value)
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 1;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      // dim 0 contribution lands in r_V ...
      dim = 0;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      // ... dim 1 contribution lands in r_U and is summed in
      dim = 1;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      r_V += r_U;
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
381 
382 //------------------------------------------------------------------------------
383 // 2D quadrature weights
384 //------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// Tensor-product 2D quadrature weights: each (threadIdx.x, threadIdx.y) point
// writes qweight1d[x]*qweight1d[y] for every element it strides over.
//------------------------------------------------------------------------------
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy];  // same for all elements
  const CeedInt stride = gridDim.x*blockDim.z;
  for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < nelem; e += stride)
    w[e*Q1D*Q1D + qx + qy*Q1D] = wq;
}
396 
397 //------------------------------------------------------------------------------
398 // 3D
399 //------------------------------------------------------------------------------
400 
401 //------------------------------------------------------------------------------
402 // Read DoFs
403 //------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// Load the column of P1D DoFs (z direction) owned by this (tidx, tidy) thread
// into registers, zero-padding positions P1D..Q1D-1 and padding threads
// outside the P1D x P1D node grid. Assumes Q1D >= P1D (r_U holds Q1D values).
//------------------------------------------------------------------------------
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *r_U) {
  const bool active = tidx < P1D && tidy < P1D;
  const int base = tidx + tidy*P1D + comp*P1D*P1D*P1D +
                   elem*BASIS_NCOMP*P1D*P1D*P1D;
  for (int z = 0; z < Q1D; z++)
    r_U[z] = (active && z < P1D) ? d_U[base + z*P1D*P1D] : 0.0;
}
415 
416 //------------------------------------------------------------------------------
417 // Read quadrature point data
418 //------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// Load the column of Q1D quadrature values (z direction) for one (component,
// derivative dim) pair into registers.
//------------------------------------------------------------------------------
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *r_U) {
  const int base = tidx + tidy*Q1D + elem*Q1D*Q1D*Q1D +
                   comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
  for (int z = 0; z < Q1D; z++)
    r_U[z] = d_U[base + z*Q1D*Q1D];
}
427 
428 //------------------------------------------------------------------------------
429 // Write DoFs
430 //------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// Store the column of P1D DoF values owned by this thread; padding threads
// (outside the P1D x P1D node grid) write nothing.
//------------------------------------------------------------------------------
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D)
    return;
  const int base = tidx + tidy*P1D + comp*P1D*P1D*P1D +
                   elem*BASIS_NCOMP*P1D*P1D*P1D;
  for (int z = 0; z < P1D; z++)
    d_V[base + z*P1D*P1D] = r_V[z];
}
441 
442 //------------------------------------------------------------------------------
443 // Write quadrature point data
444 //------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// Store the column of Q1D quadrature values for one (component, derivative
// dim) pair (mirror of readQuads3d layout).
//------------------------------------------------------------------------------
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar *r_V, CeedScalar *d_V) {
  const int base = tidx + tidy*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
                   dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
  for (int z = 0; z < Q1D; z++)
    d_V[base + z*Q1D*Q1D] = r_V[z];
}
453 
454 //------------------------------------------------------------------------------
455 // 3D tensor contract x
456 //------------------------------------------------------------------------------
// 3D x-contraction, one z-plane at a time:
// V[k](tidx, tidy) = sum_i B(i, tidx) * U[k](i, tidy).
// Each plane is staged through shared memory; both barriers are reached by all
// threads every iteration (loop bound P1D is uniform across the block).
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*Q1D + tidz*Q1D*Q1D] = U[k];  // stage plane k
    __syncthreads();                               // plane visible to all
    V[k] = 0.0;
    for (int i = 0; i < P1D; ++i)
      V[k] += B[i + tidx*P1D] * slice[i + tidy*Q1D + tidz*Q1D*Q1D]; // Contract x direction
    __syncthreads();                               // before restaging k+1
  }
}
470 
471 //------------------------------------------------------------------------------
472 // 3D tensor contract y
473 //------------------------------------------------------------------------------
// 3D y-contraction, one z-plane at a time:
// V[k](tidx, tidy) = sum_i B(i, tidy) * U[k](tidx, i).
// Barrier discipline identical to ContractX3d.
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*Q1D + tidz*Q1D*Q1D] = U[k];  // stage plane k
    __syncthreads();                               // plane visible to all
    V[k] = 0.0;
    for (int i = 0; i < P1D; ++i)
      V[k] += B[i + tidy*P1D] * slice[tidx + i*Q1D + tidz*Q1D*Q1D]; // Contract y direction
    __syncthreads();                               // before restaging k+1
  }
}
487 
488 //------------------------------------------------------------------------------
489 // 3D tensor contract z
490 //------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// 3D z-contraction, entirely in registers (the z column lives in this thread):
// V[q] = sum_p B(p, q) * U[p]. slice/tidx/tidy/tidz are unused, kept for a
// uniform contraction signature.
//------------------------------------------------------------------------------
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B,
                                   CeedScalar *V) {
  for (int q = 0; q < Q1D; ++q) {
    CeedScalar sum = 0.0;
    for (int p = 0; p < P1D; ++p)
      sum += B[p + q*P1D] * U[p]; // Contract z direction
    V[q] = sum;
  }
}
501 
502 //------------------------------------------------------------------------------
503 // 3D transpose tensor contract z
504 //------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// 3D transpose z-contraction, entirely in registers:
// V[p] = sum_q B(p, q) * U[q] for p < P1D; entries P1D..Q1D-1 are zeroed.
// slice/tidx/tidy/tidz are unused, kept for a uniform contraction signature.
//------------------------------------------------------------------------------
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int p = 0; p < Q1D; ++p) {
    CeedScalar sum = 0.0;
    if (p < P1D)
      for (int q = 0; q < Q1D; ++q)
        sum += B[p + q*P1D] * U[q]; // Contract z direction
    V[p] = sum;
  }
}
515 
516 //------------------------------------------------------------------------------
517 // 3D transpose tensor contract y
518 //------------------------------------------------------------------------------
// 3D transpose y-contraction, one z-plane at a time:
// V[k](tidx, tidy) = sum_i B(tidy, i) * U[k](tidx, i) for tidy < P1D
// (other threads keep V[k] = 0). Barriers sit outside the divergent if so
// every thread reaches them.
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*Q1D + tidz*Q1D*Q1D] = U[k];  // stage plane k
    __syncthreads();                               // plane visible to all
    V[k] = 0.0;
    if (tidy < P1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[tidy + i*P1D] * slice[tidx + i*Q1D + tidz*Q1D*Q1D]; // Contract y direction
    __syncthreads();                               // before restaging k+1
  }
}
532 
533 //------------------------------------------------------------------------------
534 // 3D transpose tensor contract x
535 //------------------------------------------------------------------------------
// 3D transpose x-contraction, one z-plane at a time:
// V[k](tidx, tidy) = sum_i B(tidx, i) * U[k](i, tidy) for tidx < P1D
// (other threads keep V[k] = 0). Barriers sit outside the divergent if so
// every thread reaches them.
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*Q1D + tidz*Q1D*Q1D] = U[k];  // stage plane k
    __syncthreads();                               // plane visible to all
    V[k] = 0.0;
    if (tidx < P1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[tidx + i*P1D] * slice[i + tidy*Q1D + tidz*Q1D*Q1D]; // Contract x direction
    __syncthreads();                               // before restaging k+1
  }
}
549 
550 //------------------------------------------------------------------------------
551 // 3D interpolate to quadrature points
552 //------------------------------------------------------------------------------
// Apply the 3D interp operator (transpose == 0: DoFs -> quad points,
// V = (B (x) B (x) B) U; otherwise the transpose) for a batch of elements.
// Each thread owns a z column of Q1D values in registers; x/y contractions go
// through shared memory, the z contraction stays in registers.
// Launch layout (inferred): blockDim.x/y span Q1D x Q1D and blockDim.z packs
// BASIS_NCOMP components per element -- TODO confirm against host launch code.
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V[Q1D];  // register z column (result / ping)
  CeedScalar r_t[Q1D];  // register z column (temporary / pong)

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;       // element slot within block
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;            // component for this thread

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // Clear both register columns before the contraction chain
    for (int i = 0; i < Q1D; ++i) {
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      // DoFs -> quad points: contract x, y, then z with B
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
    } else {
      // Quad points -> DoFs: contract z, y, then x with B^T
      readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
    }
  }
}
589 
590 //------------------------------------------------------------------------------
591 // 3D derivatives at quadrature points
592 //------------------------------------------------------------------------------
// Apply the 3D gradient operator for a batch of elements.
// transpose == 0: produce d/dx (dim 0, G(x)B(x)B), d/dy (dim 1, B(x)G(x)B)
// and d/dz (dim 2, B(x)B(x)G) at quadrature points. transpose != 0: sum the
// transpose application of all three derivative dimensions back into DoFs.
// Same inferred launch layout as interp3d.
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // Use P1D for one of these
  // (existing note: one of these register columns only ever needs P1D entries)
  CeedScalar r_U[Q1D];  // input z column (kept live across all three dims)
  CeedScalar r_V[Q1D];  // result z column
  CeedScalar r_t[Q1D];  // temporary z column between contractions

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;       // element slot within block
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;            // component for this thread
  int dim;                                      // derivative dimension (0..2)

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if (!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // dim 0: G in x
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // dim 1: G in y
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_G, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 1;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // dim 2: G in z
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 2;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      // dim 0 contribution initializes r_V ...
      dim = 0;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      // ... dim 1 and dim 2 contributions are accumulated via add()
      dim = 1;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_G, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      dim = 2;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
652 
653 //------------------------------------------------------------------------------
654 // 3D quadrature weights
655 //------------------------------------------------------------------------------
//------------------------------------------------------------------------------
// Tensor-product 3D quadrature weights: each (x, y, z) thread writes
// qweight1d[x]*qweight1d[y]*qweight1d[z] for every element it strides over.
//------------------------------------------------------------------------------
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const int qz = threadIdx.z;
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy]*qweight1d[qz];
  for (int e = blockIdx.x; e < nelem; e += gridDim.x)
    w[e*Q1D*Q1D*Q1D + qx + qy*Q1D + qz*Q1D*Q1D] = wq;
}
667 
668 
669 //------------------------------------------------------------------------------
670 // Basis kernels
671 //------------------------------------------------------------------------------
672 
673 //------------------------------------------------------------------------------
674 // Interp kernel by dim
675 //------------------------------------------------------------------------------
extern "C" __global__ void interp(const CeedInt nelem, const int transpose,
                                  const CeedScalar *c_B,
                                  const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  // Dynamically-sized shared memory scratch used by the per-dimension kernels
  extern __shared__ double slice[];
  // BASIS_DIM is fixed at compile (NVRTC) time, so only one branch is live
  switch (BASIS_DIM) {
  case 1: interp1d(nelem, transpose, c_B, d_U, d_V, slice); break;
  case 2: interp2d(nelem, transpose, c_B, d_U, d_V, slice); break;
  case 3: interp3d(nelem, transpose, c_B, d_U, d_V, slice); break;
  }
}
689 
690 //------------------------------------------------------------------------------
691 // Grad kernel by dim
692 //------------------------------------------------------------------------------
extern "C" __global__ void grad(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *c_G,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V) {
  // Dynamically-sized shared memory scratch used by the per-dimension kernels
  extern __shared__ double slice[];
  // BASIS_DIM is fixed at compile (NVRTC) time, so only one branch is live
  switch (BASIS_DIM) {
  case 1: grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice); break;
  case 2: grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice); break;
  case 3: grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice); break;
  }
}
706 
707 //------------------------------------------------------------------------------
708 // Weight kernels by dim
709 //------------------------------------------------------------------------------
extern "C" __global__ void weight(const CeedInt nelem,
                                  const CeedScalar *__restrict__ qweight1d,
                                  CeedScalar *__restrict__ v) {
  // BASIS_DIM is fixed at compile (NVRTC) time, so only one branch is live
  switch (BASIS_DIM) {
  case 1: weight1d(nelem, qweight1d, v); break;
  case 2: weight2d(nelem, qweight1d, v); break;
  case 3: weight3d(nelem, qweight1d, v); break;
  }
}
721 
722 );
723 // *INDENT-ON*
724 
725 //------------------------------------------------------------------------------
726 // Device initalization
727 //------------------------------------------------------------------------------
// Stage the 1D interpolation matrix d_B for kernel use via c_B.
// NOTE(review): defined elsewhere in the CUDA backend; presumably places the
// matrix where the NVRTC kernels can read it — confirm against its definition.
int CeedCudaInitInterp(CeedScalar *d_B, CeedInt P1d, CeedInt Q1d,
                       CeedScalar **c_B);
// Stage both the 1D interpolation (d_B) and gradient (d_G) matrices for
// kernel use via c_B_ptr/c_G_ptr; companion to CeedCudaInitInterp.
int CeedCudaInitInterpGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P1d,
                           CeedInt Q1d, CeedScalar **c_B_ptr,
                           CeedScalar **c_G_ptr);
733 
734 //------------------------------------------------------------------------------
735 // Apply basis
736 //------------------------------------------------------------------------------
737 int CeedBasisApplyTensor_Cuda_shared(CeedBasis basis, const CeedInt nelem,
738                                      CeedTransposeMode tmode,
739                                      CeedEvalMode emode, CeedVector u,
740                                      CeedVector v) {
741   int ierr;
742   Ceed ceed;
743   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
744   Ceed_Cuda_shared *ceed_Cuda;
745   CeedGetData(ceed, (void *) &ceed_Cuda); CeedChk(ierr);
746   CeedBasis_Cuda_shared *data;
747   CeedBasisGetData(basis, (void *)&data); CeedChk(ierr);
748   const CeedInt transpose = tmode == CEED_TRANSPOSE;
749   CeedInt dim, ncomp;
750   ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
751   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
752 
753   // Read vectors
754   const CeedScalar *d_u;
755   CeedScalar *d_v;
756   if (emode != CEED_EVAL_WEIGHT) {
757     ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChk(ierr);
758   }
759   ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChk(ierr);
760 
761   // Clear v for transpose mode
762   if (tmode == CEED_TRANSPOSE) {
763     CeedInt length;
764     ierr = CeedVectorGetLength(v, &length); CeedChk(ierr);
765     ierr = cudaMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChk(ierr);
766   }
767 
768   // Apply basis operation
769   switch (emode) {
770   case CEED_EVAL_INTERP: {
771     CeedInt P1d, Q1d;
772     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
773     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
774     ierr = CeedCudaInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B);
775     CeedChk(ierr);
776     void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
777                           &d_u, &d_v
778                          };
779     if (dim == 1) {
780       CeedInt elemsPerBlock = 32;
781       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
782                                              ? 1 : 0 );
783       CeedInt sharedMem = elemsPerBlock*Q1d*sizeof(CeedScalar);
784       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, 1,
785                                         elemsPerBlock, sharedMem,
786                                         interpargs); CeedChk(ierr);
787     } else if (dim == 2) {
788       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
789       CeedInt elemsPerBlock = Q1d < 7 ? optElems[Q1d]/ncomp : 1;
790       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
791                                              ? 1 : 0 );
792       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
793       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, Q1d,
794                                         ncomp*elemsPerBlock, sharedMem,
795                                         interpargs); CeedChk(ierr);
796     } else if (dim == 3) {
797       CeedInt elemsPerBlock = 1;
798       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
799                                              ? 1 : 0 );
800       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
801       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, Q1d,
802                                         ncomp*elemsPerBlock, sharedMem,
803                                         interpargs); CeedChk(ierr);
804     }
805   } break;
806   case CEED_EVAL_GRAD: {
807     CeedInt P1d, Q1d;
808     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
809     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
810     ierr = CeedCudaInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d,
811                                   Q1d, &data->c_B, &data->c_G);
812     CeedChk(ierr);
813     void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
814                         &data->c_G, &d_u, &d_v
815                        };
816     if (dim == 1) {
817       CeedInt elemsPerBlock = 32;
818       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
819                                              ? 1 : 0 );
820       CeedInt sharedMem = elemsPerBlock*Q1d*sizeof(CeedScalar);
821       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, 1,
822                                         elemsPerBlock, sharedMem, gradargs);
823       CeedChk(ierr);
824     } else if (dim == 2) {
825       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
826       CeedInt elemsPerBlock = Q1d < 7 ? optElems[Q1d]/ncomp : 1;
827       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
828                                              ? 1 : 0 );
829       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
830       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, Q1d,
831                                         ncomp*elemsPerBlock, sharedMem,
832                                         gradargs); CeedChk(ierr);
833     } else if (dim == 3) {
834       CeedInt elemsPerBlock = 1;
835       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
836                                              ? 1 : 0 );
837       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
838       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, Q1d,
839                                         ncomp*elemsPerBlock, sharedMem,
840                                         gradargs); CeedChk(ierr);
841     }
842   } break;
843   case CEED_EVAL_WEIGHT: {
844     CeedInt Q1d;
845     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
846     void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
847     if (dim == 1) {
848       const CeedInt elemsPerBlock = 32/Q1d;
849       const CeedInt gridsize = nelem/elemsPerBlock + ( (
850                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
851       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d,
852                                   elemsPerBlock, 1, weightargs);
853       CeedChk(ierr);
854     } else if (dim == 2) {
855       const CeedInt optElems = 32/(Q1d*Q1d);
856       const CeedInt elemsPerBlock = optElems>0?optElems:1;
857       const CeedInt gridsize = nelem/elemsPerBlock + ( (
858                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
859       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d,
860                                   elemsPerBlock, weightargs);
861       CeedChk(ierr);
862     } else if (dim == 3) {
863       const CeedInt gridsize = nelem;
864       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d, Q1d,
865                                   weightargs);
866       CeedChk(ierr);
867     }
868   } break;
869   // LCOV_EXCL_START
870   // Evaluate the divergence to/from the quadrature points
871   case CEED_EVAL_DIV:
872     return CeedError(ceed, 1, "CEED_EVAL_DIV not supported");
873   // Evaluate the curl to/from the quadrature points
874   case CEED_EVAL_CURL:
875     return CeedError(ceed, 1, "CEED_EVAL_CURL not supported");
876   // Take no action, BasisApply should not have been called
877   case CEED_EVAL_NONE:
878     return CeedError(ceed, 1,
879                      "CEED_EVAL_NONE does not make sense in this context");
880     // LCOV_EXCL_STOP
881   }
882 
883   // Restore vectors
884   if (emode != CEED_EVAL_WEIGHT) {
885     ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChk(ierr);
886   }
887   ierr = CeedVectorRestoreArray(v, &d_v); CeedChk(ierr);
888   return 0;
889 }
890 
891 //------------------------------------------------------------------------------
892 // Destroy basis
893 //------------------------------------------------------------------------------
894 static int CeedBasisDestroy_Cuda_shared(CeedBasis basis) {
895   int ierr;
896   Ceed ceed;
897   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
898 
899   CeedBasis_Cuda_shared *data;
900   ierr = CeedBasisGetData(basis, (void *) &data); CeedChk(ierr);
901 
902   CeedChk_Cu(ceed, cuModuleUnload(data->module));
903 
904   ierr = cudaFree(data->d_qweight1d); CeedChk_Cu(ceed, ierr);
905   ierr = cudaFree(data->d_interp1d); CeedChk_Cu(ceed, ierr);
906   ierr = cudaFree(data->d_grad1d); CeedChk_Cu(ceed, ierr);
907 
908   ierr = CeedFree(&data); CeedChk(ierr);
909 
910   return 0;
911 }
912 
913 //------------------------------------------------------------------------------
914 // Create tensor basis
915 //------------------------------------------------------------------------------
916 int CeedBasisCreateTensorH1_Cuda_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
917                                         const CeedScalar *interp1d,
918                                         const CeedScalar *grad1d,
919                                         const CeedScalar *qref1d,
920                                         const CeedScalar *qweight1d,
921                                         CeedBasis basis) {
922   int ierr;
923   Ceed ceed;
924   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
925   if (Q1d<P1d) {
926     return CeedError(ceed, 1, "Backend does not implement underintegrated basis.");
927   }
928   CeedBasis_Cuda_shared *data;
929   ierr = CeedCalloc(1, &data); CeedChk(ierr);
930 
931   // Copy basis data to GPU
932   const CeedInt qBytes = Q1d * sizeof(CeedScalar);
933   ierr = cudaMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Cu(ceed, ierr);
934   ierr = cudaMemcpy(data->d_qweight1d, qweight1d, qBytes,
935                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
936 
937   const CeedInt iBytes = qBytes * P1d;
938   ierr = cudaMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Cu(ceed, ierr);
939   ierr = cudaMemcpy(data->d_interp1d, interp1d, iBytes,
940                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
941 
942   ierr = cudaMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Cu(ceed, ierr);
943   ierr = cudaMemcpy(data->d_grad1d, grad1d, iBytes,
944                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
945 
946   // Compute collocated gradient and copy to GPU
947   data->d_collograd1d = NULL;
948   if (dim == 3 && Q1d >= P1d) {
949     CeedScalar *collograd1d;
950     ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChk(ierr);
951     ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChk(ierr);
952     ierr = cudaMalloc((void **)&data->d_collograd1d, qBytes * Q1d);
953     CeedChk_Cu(ceed, ierr);
954     ierr = cudaMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d,
955                       cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
956   }
957 
958   // Compile basis kernels
959   CeedInt ncomp;
960   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
961   ierr = CeedCompileCuda(ceed, kernelsShared, &data->module, 7,
962                          "Q1D", Q1d,
963                          "P1D", P1d,
964                          "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
965                              Q1d : P1d, dim),
966                          "BASIS_DIM", dim,
967                          "BASIS_NCOMP", ncomp,
968                          "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
969                          "BASIS_NQPT", CeedIntPow(Q1d, dim)
970                         ); CeedChk(ierr);
971   ierr = CeedGetKernelCuda(ceed, data->module, "interp", &data->interp);
972   CeedChk(ierr);
973   ierr = CeedGetKernelCuda(ceed, data->module, "grad", &data->grad);
974   CeedChk(ierr);
975   ierr = CeedGetKernelCuda(ceed, data->module, "weight", &data->weight);
976   CeedChk(ierr);
977 
978   ierr = CeedBasisSetData(basis, (void *)&data); CeedChk(ierr);
979 
980   // Register backend functions
981   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
982                                 CeedBasisApplyTensor_Cuda_shared);
983   CeedChk(ierr);
984   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
985                                 CeedBasisDestroy_Cuda_shared); CeedChk(ierr);
986   return 0;
987 }
988 //------------------------------------------------------------------------------
989