xref: /libCEED/backends/cuda-shared/ceed-cuda-shared-basis.c (revision cd11335e55ed5306fc93608d50d5700c3a391b92)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include "ceed-cuda-shared.h"
18 
19 //------------------------------------------------------------------------------
20 // Shared mem kernels
21 //------------------------------------------------------------------------------
22 // *INDENT-OFF*
23 static const char *kernelsShared = QUOTE(
24 
25 //------------------------------------------------------------------------------
26 // Sum input into output
27 //------------------------------------------------------------------------------
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  // Accumulate the Q1D register entries of r_U into r_V, one at a time
  for (int q = 0; q < Q1D; ++q)
    r_V[q] = r_V[q] + r_U[q];
}
32 
33 //------------------------------------------------------------------------------
34 // 1D
35 //------------------------------------------------------------------------------
36 
37 //------------------------------------------------------------------------------
38 // Read DoFs
39 //------------------------------------------------------------------------------
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int tidz,const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *slice) {
  // Stage the P1D nodal values of (elem, comp) into this thread's tidz row of
  // the shared slice, zero-padding entries P1D..Q1D-1 so the contraction can
  // loop over a full row of Q1D scalars
  for (int node = 0; node < Q1D; ++node) {
    const CeedScalar val = (node < P1D) ?
                           d_U[node + elem*P1D + comp*P1D*nelem] : 0.0;
    slice[node + tidz*Q1D] = val;
  }
}
49 
50 //------------------------------------------------------------------------------
51 // Write DoFs
52 //------------------------------------------------------------------------------
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  // Only the first P1D threads in x hold valid nodal results
  if (tidx >= P1D) return;
  d_V[tidx + elem*P1D + comp*P1D*nelem] = r_V;
}
60 
61 //------------------------------------------------------------------------------
62 // Read quadrature point data
63 //------------------------------------------------------------------------------
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int tidz, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *slice) {
  // Stage the Q1D quadrature values of (elem, comp, dim) into this thread's
  // tidz row of the shared slice
  const int base = elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  for (int q = 0; q < Q1D; ++q)
    slice[q + tidz*Q1D] = d_U[base + q];
}
72 
73 //------------------------------------------------------------------------------
74 // Write quadrature point data
75 //------------------------------------------------------------------------------
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  // Each x-thread owns one quadrature point of (elem, comp, dim)
  const int ind = tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  d_V[ind] = r_V;
}
82 
83 //------------------------------------------------------------------------------
84 // 1D tensor contraction
85 //------------------------------------------------------------------------------
// Contract the nodal values staged in slice with B, producing one quadrature
// value per x-thread: V = sum_i B[i + tidx*P1D] * slice[i + tidz*Q1D].
// NOTE(review): the U argument is unused here — the input comes from the
// shared slice staged by readDofs1d; kept for signature symmetry, presumably.
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  V = 0.0;
  for (int i = 0; i < P1D; ++i)
    V += B[i + tidx*P1D] * slice[i + tidz*Q1D]; // Contract x direction
}
94 
95 //------------------------------------------------------------------------------
96 // 1D transpose tensor contraction
97 //------------------------------------------------------------------------------
// Transpose contraction: map the Q1D quadrature values staged in slice back
// to P1D nodal values, V = sum_i B[tidx + i*P1D] * slice[i + tidz*Q1D].
// The tidx < P1D guard is required: the block is launched Q1D threads wide,
// and without it threads with tidx >= P1D index B (which holds P1D*Q1D
// entries, accessed as B[tidx + i*P1D]) out of bounds. Their result is
// discarded anyway — writeDofs1d only stores for tidx < P1D — and the 2D/3D
// transpose contractions below carry the same guard.
// NOTE(review): the U argument is unused; input comes from the shared slice.
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  V = 0.0;
  if (tidx < P1D) {
    for (int i = 0; i < Q1D; ++i)
      V += B[tidx + i*P1D] * slice[i + tidz*Q1D]; // Contract x direction
  }
}
105 
106 //------------------------------------------------------------------------------
107 // 1D interpolate to quadrature points
108 //------------------------------------------------------------------------------
// Apply the 1D interpolation basis over a batch of elements.
// transpose == 0 maps nodal values to quadrature points via c_B; otherwise
// the transpose map is applied. Each z-layer of the block handles one
// element; blocks and z-layers stride over the nelem elements together.
// slice is block-shared scratch, one row of Q1D scalars per tidz
// (staged by readDofs1d/readQuads1d).
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;  // contraction output register
  CeedScalar r_t;  // NOTE(review): passed uninitialized as the unused U
                   // argument of the contractions, which read from slice —
                   // harmless but confusing; confirm before relying on U

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;


  // Stride elements across blocks and z-layers
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for (int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) {
        // nodes -> quadrature points
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      } else {
        // quadrature points -> nodes
        readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
137 
138 //------------------------------------------------------------------------------
139 // 1D derivatives at quadrature points
140 //------------------------------------------------------------------------------
// Apply the 1D gradient basis over a batch of elements.
// transpose == 0 evaluates nodal values' derivatives at quadrature points
// via c_G; otherwise the transpose map is applied. c_B is unused in 1D
// (there is a single direction, which always takes the derivative matrix).
// slice is block-shared scratch, one row of Q1D scalars per tidz.
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  CeedScalar r_U;  // NOTE(review): passed uninitialized as the unused U
                   // argument of the contractions (input comes from slice)
  CeedScalar r_V;  // contraction output register

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  int dim;

  // Stride elements across blocks and z-layers
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) {
        // nodes -> derivative at quadrature points (single dim in 1D)
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        dim = 0;
        writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      } else {
        // transpose: quadrature derivative data -> nodes
        dim = 0;
        readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
171 
172 //------------------------------------------------------------------------------
173 // 1D Quadrature weights
174 //------------------------------------------------------------------------------
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  // Each x-thread owns one 1D quadrature point; y-threads and blocks
  // together stride over the elements, broadcasting the same weight
  const int qpt = threadIdx.x;
  const CeedScalar wq = qweight1d[qpt];
  CeedInt elem = blockIdx.x*blockDim.y + threadIdx.y;
  while (elem < nelem) {
    w[elem*Q1D + qpt] = wq;
    elem += gridDim.x*blockDim.y;
  }
}
185 
186 //------------------------------------------------------------------------------
187 // 2D
188 //------------------------------------------------------------------------------
189 
190 //------------------------------------------------------------------------------
191 // Read DoFs
192 //------------------------------------------------------------------------------
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar &U) {
  // Threads outside the P1D x P1D node grid load a zero pad value
  if (tidx < P1D && tidy < P1D) {
    U = d_U[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem];
  } else {
    U = 0.0;
  }
}
200 
201 //------------------------------------------------------------------------------
202 // Write DoFs
203 //------------------------------------------------------------------------------
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  // Only threads inside the P1D x P1D node grid hold valid results
  if (tidx >= P1D || tidy >= P1D) return;
  d_V[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] = r_V;
}
211 
212 //------------------------------------------------------------------------------
213 // Read quadrature point data
214 //------------------------------------------------------------------------------
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar &U ) {
  // Each (tidx, tidy) thread owns one 2D quadrature point of (elem, comp, dim)
  const int ind = tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                  dim*BASIS_NCOMP*nelem*Q1D*Q1D;
  U = d_U[ind];
}
222 
223 //------------------------------------------------------------------------------
224 // Write quadrature point data
225 //------------------------------------------------------------------------------
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  // Each (tidx, tidy) thread owns one 2D quadrature point of (elem, comp, dim)
  const int ind = tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                  dim*BASIS_NCOMP*nelem*Q1D*Q1D;
  d_V[ind] = r_V;
}
233 
234 //------------------------------------------------------------------------------
235 // 2D tensor contraction x
236 //------------------------------------------------------------------------------
// Contract the x index of U with B: each thread publishes its U to the
// shared slice, then reduces its x-row against B[i + tidx*P1D].
// Both __syncthreads() calls must be reached by every thread in the block —
// callers must not invoke this from divergent control flow.
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx + tidy*Q1D + tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  for (int i = 0; i < P1D; ++i)
    V += B[i + tidx*P1D] * slice[i + tidy*Q1D + tidz*Q1D*Q1D]; // Contract x direction
  // Second barrier keeps slice stable until every thread is done reading
  __syncthreads();
}
248 
249 //------------------------------------------------------------------------------
250 // 2D tensor contraction y
251 //------------------------------------------------------------------------------
// Contract the y index of U with B: each thread publishes its U to the
// shared slice, then reduces its y-column against B[i + tidy*P1D].
// Both __syncthreads() calls must be reached by every thread in the block.
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx + tidy*Q1D + tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  for (int i = 0; i < P1D; ++i)
    V += B[i + tidy*P1D] * slice[tidx + i*Q1D + tidz*Q1D*Q1D]; // Contract y direction
  // Second barrier keeps slice stable until every thread is done reading
  __syncthreads();
}
263 
264 //------------------------------------------------------------------------------
265 // 2D transpose tensor contraction y
266 //------------------------------------------------------------------------------
// Transpose contraction over y: quadrature values back to nodes.
// The tidy < P1D guard keeps the B indexing (B[tidy + i*P1D]) in bounds;
// threads with tidy >= P1D just produce V = 0. The __syncthreads() calls
// are outside the guard, so every thread still reaches both barriers.
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx + tidy*Q1D + tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidy < P1D)
    for (int i = 0; i < Q1D; ++i)
      V += B[tidy + i*P1D] * slice[tidx + i*Q1D + tidz*Q1D*Q1D]; // Contract y direction
  __syncthreads();
}
278 
279 //------------------------------------------------------------------------------
280 // 2D transpose tensor contraction x
281 //------------------------------------------------------------------------------
// Transpose contraction over x: quadrature values back to nodes.
// The tidx < P1D guard keeps the B indexing (B[tidx + i*P1D]) in bounds;
// threads with tidx >= P1D just produce V = 0. The __syncthreads() calls
// are outside the guard, so every thread still reaches both barriers.
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx + tidy*Q1D + tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidx < P1D)
    for (int i = 0; i < Q1D; ++i)
      V += B[tidx + i*P1D] * slice[i + tidy*Q1D + tidz*Q1D*Q1D]; // Contract x direction
  __syncthreads();
}
293 
294 //------------------------------------------------------------------------------
295 // 2D interpolate to quadrature points
296 //------------------------------------------------------------------------------
// Apply the 2D tensor-product interpolation basis over a batch of elements.
// Each z-layer of the block handles one (element, component) pair:
// tidz/BASIS_NCOMP is the element slot within the block and
// tidz%BASIS_NCOMP the component. transpose == 0 maps nodes -> quadrature
// points; otherwise the transpose map. slice is block-shared scratch used
// by the contraction helpers (one Q1D*Q1D plane per tidz).
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // Fix: removed a redundant inner redeclaration of comp that shadowed
    // the identical value computed above the loop
    r_V = 0.0;
    r_t = 0.0;
    if (!transpose) {
      // nodes -> quadrature points: contract x then y with c_B
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
    } else {
      // quadrature points -> nodes: transpose contractions in reverse order
      readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
330 
331 //------------------------------------------------------------------------------
332 // 2D derivatives at quadrature points
333 //------------------------------------------------------------------------------
// Apply the 2D gradient basis over a batch of elements.
// Forward (transpose == 0): for each (elem, comp), produce the x-derivative
// (c_G in x, c_B in y) as dim 0 and the y-derivative (c_B in x, c_G in y)
// as dim 1. Transpose: read both dim slots, apply the transpose
// contractions, and sum the two contributions before writing the nodes.
// Each z-layer of the block handles one (element, component) pair.
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V, CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;    // element slot within the block
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;         // component handled by this layer
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if (!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // dim 0: derivative in x (G in x, B in y)
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // dim 1: derivative in y (B in x, G in y)
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 1;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      // dim 0 contribution
      dim = 0;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      // dim 1 contribution, accumulated into r_V
      dim = 1;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      r_V += r_U;
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
376 
377 //------------------------------------------------------------------------------
378 // 2D quadrature weights
379 //------------------------------------------------------------------------------
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  // Thread (x, y) owns one 2D quadrature point; z-threads and blocks
  // together stride over the elements, storing the tensor-product weight
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy];
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    w[elem*Q1D*Q1D + qx + qy*Q1D] = wq;
  }
}
391 
392 //------------------------------------------------------------------------------
393 // 3D
394 //------------------------------------------------------------------------------
395 
396 //------------------------------------------------------------------------------
397 // Read DoFs
398 //------------------------------------------------------------------------------
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *r_U) {
  // Load a z-column of P1D nodal values into registers, zero-padding both
  // the out-of-grid (tidx/tidy >= P1D) threads and entries P1D..Q1D-1
  const bool inGrid = (tidx < P1D && tidy < P1D);
  for (int z = 0; z < Q1D; ++z) {
    if (inGrid && z < P1D)
      r_U[z] = d_U[tidx + tidy*P1D + z*P1D*P1D + elem*P1D*P1D*P1D +
                   comp*P1D*P1D*P1D*nelem];
    else
      r_U[z] = 0.0;
  }
}
410 
411 //------------------------------------------------------------------------------
412 // Write DoFs
413 //------------------------------------------------------------------------------
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V,
                                   CeedScalar *d_V) {
  // Only threads inside the P1D x P1D node grid hold valid results;
  // each writes its z-column of P1D nodal values
  if (tidx >= P1D || tidy >= P1D) return;
  for (int z = 0; z < P1D; ++z)
    d_V[tidx + tidy*P1D + z*P1D*P1D + elem*P1D*P1D*P1D +
        comp*P1D*P1D*P1D*nelem] = r_V[z];
}
424 
425 //------------------------------------------------------------------------------
426 // Read quadrature point data
427 //------------------------------------------------------------------------------
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *r_U) {
  // Load this thread's z-column of Q1D quadrature values into registers
  const int base = tidx + tidy*Q1D + elem*Q1D*Q1D*Q1D +
                   comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
  for (int z = 0; z < Q1D; ++z)
    r_U[z] = d_U[base + z*Q1D*Q1D];
}
436 
437 //------------------------------------------------------------------------------
438 // Write quadrature point data
439 //------------------------------------------------------------------------------
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar *r_V, CeedScalar *d_V) {
  // Store this thread's z-column of Q1D quadrature values
  const int base = tidx + tidy*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
                   dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
  for (int z = 0; z < Q1D; ++z)
    d_V[base + z*Q1D*Q1D] = r_V[z];
}
448 
449 //------------------------------------------------------------------------------
450 // 3D tensor contract x
451 //------------------------------------------------------------------------------
// Contract the x index of each of the P1D z-levels of U with B.
// For every level k, the thread publishes U[k] to its slot in the shared
// slice plane, then reduces its x-row against B[i + tidx*P1D].
// Both __syncthreads() calls sit in a loop with a uniform trip count (P1D),
// so every thread in the block reaches every barrier.
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*Q1D + tidz*Q1D*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    for (int i = 0; i < P1D; ++i)
      V[k] += B[i + tidx*P1D] * slice[i + tidy*Q1D + tidz*Q1D*Q1D]; // Contract x direction
    __syncthreads();
  }
}
465 
466 //------------------------------------------------------------------------------
467 // 3D tensor contract y
468 //------------------------------------------------------------------------------
// Contract the y index of each of the P1D z-levels of U with B.
// For every level k, the thread publishes U[k] to the shared slice plane,
// then reduces its y-column against B[i + tidy*P1D].
// Barriers have a uniform trip count; all threads must call this together.
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*Q1D + tidz*Q1D*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    for (int i = 0; i < P1D; ++i)
      V[k] += B[i + tidy*P1D] * slice[tidx + i*Q1D + tidz*Q1D*Q1D]; // Contract y direction
    __syncthreads();
  }
}
482 
483 //------------------------------------------------------------------------------
484 // 3D tensor contract z
485 //------------------------------------------------------------------------------
// Contract the z index of U with B, entirely in registers: each thread owns
// a full z-column (U[0..P1D-1]), so no shared memory or barrier is needed.
// slice, tidx, tidy, tidz are unused; kept for signature symmetry with the
// x/y contractions.
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < Q1D; ++k) {
    V[k] = 0.0;
    for (int i = 0; i < P1D; ++i)
      V[k] += B[i + k*P1D] * U[i]; // Contract z direction
  }
}
496 
497 //------------------------------------------------------------------------------
498 // 3D transpose tensor contract z
499 //------------------------------------------------------------------------------
// Transpose contraction over z, entirely in registers (each thread owns a
// full z-column). Output levels k >= P1D are zeroed; the k < P1D guard also
// keeps the B indexing (B[k + i*P1D]) in bounds. slice and the tid
// arguments are unused; kept for signature symmetry.
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < Q1D; ++k) {
    V[k] = 0.0;
    if (k < P1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[k + i*P1D] * U[i]; // Contract z direction
  }
}
510 
511 //------------------------------------------------------------------------------
512 // 3D transpose tensor contract y
513 //------------------------------------------------------------------------------
// Transpose contraction over y for each of the P1D z-levels.
// The tidy < P1D guard keeps the B indexing (B[tidy + i*P1D]) in bounds;
// out-of-range threads produce V[k] = 0. Barriers are outside the guard and
// in a uniform-trip-count loop, so every thread reaches each one.
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*Q1D + tidz*Q1D*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidy < P1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[tidy + i*P1D] * slice[tidx + i*Q1D + tidz*Q1D*Q1D]; // Contract y direction
    __syncthreads();
  }
}
527 
528 //------------------------------------------------------------------------------
529 // 3D transpose tensor contract x
530 //------------------------------------------------------------------------------
// Transpose contraction over x for each of the P1D z-levels.
// The tidx < P1D guard keeps the B indexing (B[tidx + i*P1D]) in bounds;
// out-of-range threads produce V[k] = 0. Barriers are outside the guard and
// in a uniform-trip-count loop, so every thread reaches each one.
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*Q1D + tidz*Q1D*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < P1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[tidx + i*P1D] * slice[i + tidy*Q1D + tidz*Q1D*Q1D]; // Contract x direction
    __syncthreads();
  }
}
544 
545 //------------------------------------------------------------------------------
546 // 3D interpolate to quadrature points
547 //------------------------------------------------------------------------------
// Apply the 3D tensor-product interpolation basis over a batch of elements.
// Each z-layer of the block handles one (element, component) pair; each
// thread carries a z-column of Q1D values in registers (r_V, r_t) and the
// x/y contractions exchange data through the shared slice.
// transpose == 0 maps nodes -> quadrature points; otherwise the transpose.
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V[Q1D];  // register z-column, ping-ponged with r_t
  CeedScalar r_t[Q1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;     // element slot within the block
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;          // component handled by this layer

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // Clear registers so stale values never leak between elements
    for (int i = 0; i < Q1D; ++i) {
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      // nodes -> quadrature points: contract x, y, then z with c_B
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
    } else {
      // quadrature points -> nodes: transpose contractions in reverse order
      readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
    }
  }
}
584 
585 //------------------------------------------------------------------------------
586 // 3D derivatives at quadrature points
587 //------------------------------------------------------------------------------
// Apply the 3D gradient basis over a batch of elements.
// Forward (transpose == 0): for each (elem, comp), produce dim 0/1/2 as the
// x/y/z-derivative respectively — c_G applied in that direction, c_B in the
// other two. Transpose: read all three dim slots, apply the transpose
// contractions, and accumulate the three contributions (via add) before
// writing the nodes. Each z-layer of the block handles one (element,
// component) pair; threads carry z-columns of Q1D values in registers.
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // Use P1D for one of these
  CeedScalar r_U[Q1D];
  CeedScalar r_V[Q1D];
  CeedScalar r_t[Q1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;     // element slot within the block
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;          // component handled by this layer
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if (!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // dim 0: derivative in x (G, B, B)
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // dim 1: derivative in y (B, G, B)
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_G, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 1;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // dim 2: derivative in z (B, B, G)
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 2;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      // dim 0 contribution, written directly into r_V
      dim = 0;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      // dim 1 contribution, accumulated into r_V
      dim = 1;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_G, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      // dim 2 contribution, accumulated into r_V
      dim = 2;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
647 
648 //------------------------------------------------------------------------------
649 // 3D quadrature weights
650 //------------------------------------------------------------------------------
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  // Thread (x, y, z) owns one 3D quadrature point; blocks stride over the
  // elements, storing the tensor-product weight
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const int qz = threadIdx.z;
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy]*qweight1d[qz];
  for (int e = blockIdx.x; e < nelem; e += gridDim.x)
    w[e*Q1D*Q1D*Q1D + qx + qy*Q1D + qz*Q1D*Q1D] = wq;
}
662 
663 
664 //------------------------------------------------------------------------------
665 // Basis kernels
666 //------------------------------------------------------------------------------
667 
668 //------------------------------------------------------------------------------
669 // Interp kernel by dim
670 //------------------------------------------------------------------------------
extern "C" __global__ void interp(const CeedInt nelem, const int transpose,
                                  const CeedScalar *c_B,
                                  const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  // Dynamic shared memory scratch used by the per-dimension implementations
  extern __shared__ double slice[];
  // Dispatch on the compile-time BASIS_DIM define
  switch (BASIS_DIM) {
  case 1:
    interp1d(nelem, transpose, c_B, d_U, d_V, slice);
    break;
  case 2:
    interp2d(nelem, transpose, c_B, d_U, d_V, slice);
    break;
  case 3:
    interp3d(nelem, transpose, c_B, d_U, d_V, slice);
    break;
  }
}
684 
685 //------------------------------------------------------------------------------
686 // Grad kernel by dim
687 //------------------------------------------------------------------------------
extern "C" __global__ void grad(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *c_G,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V) {
  // Dynamic shared memory scratch used by the per-dimension implementations
  extern __shared__ double slice[];
  // Dispatch on the compile-time BASIS_DIM define
  switch (BASIS_DIM) {
  case 1:
    grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
    break;
  case 2:
    grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
    break;
  case 3:
    grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
    break;
  }
}
701 
702 //------------------------------------------------------------------------------
703 // Weight kernels by dim
704 //------------------------------------------------------------------------------
extern "C" __global__ void weight(const CeedInt nelem,
                                  const CeedScalar *__restrict__ qweight1d,
                                  CeedScalar *__restrict__ v) {
  // Dispatch on the compile-time BASIS_DIM define
  switch (BASIS_DIM) {
  case 1:
    weight1d(nelem, qweight1d, v);
    break;
  case 2:
    weight2d(nelem, qweight1d, v);
    break;
  case 3:
    weight3d(nelem, qweight1d, v);
    break;
  }
}
716 
717 );
718 // *INDENT-ON*
719 
720 //------------------------------------------------------------------------------
// Device initialization
722 //------------------------------------------------------------------------------
// Forward declarations: device set-up helpers implemented elsewhere in this
// backend. They take the device copies of the 1D interp (and grad) matrices
// and return pointers usable by the kernels via c_B / c_G (presumably
// device/constant memory -- confirm in the implementation).
int CeedCudaInitInterp(CeedScalar *d_B, CeedInt P1d, CeedInt Q1d,
                       CeedScalar **c_B);
int CeedCudaInitInterpGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P1d,
                           CeedInt Q1d, CeedScalar **c_B_ptr,
                           CeedScalar **c_G_ptr);
728 
729 //------------------------------------------------------------------------------
730 // Apply basis
731 //------------------------------------------------------------------------------
732 int CeedBasisApplyTensor_Cuda_shared(CeedBasis basis, const CeedInt nelem,
733                                      CeedTransposeMode tmode,
734                                      CeedEvalMode emode, CeedVector u,
735                                      CeedVector v) {
736   int ierr;
737   Ceed ceed;
738   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
739   Ceed_Cuda_shared *ceed_Cuda;
740   CeedGetData(ceed, (void *) &ceed_Cuda); CeedChk(ierr);
741   CeedBasis_Cuda_shared *data;
742   CeedBasisGetData(basis, (void *)&data); CeedChk(ierr);
743   const CeedInt transpose = tmode == CEED_TRANSPOSE;
744   CeedInt dim, ncomp;
745   ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
746   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
747 
748   // Read vectors
749   const CeedScalar *d_u;
750   CeedScalar *d_v;
751   if (emode != CEED_EVAL_WEIGHT) {
752     ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChk(ierr);
753   }
754   ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChk(ierr);
755 
756   // Clear v for transpose mode
757   if (tmode == CEED_TRANSPOSE) {
758     CeedInt length;
759     ierr = CeedVectorGetLength(v, &length); CeedChk(ierr);
760     ierr = cudaMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChk(ierr);
761   }
762 
763   // Apply basis operation
764   switch (emode) {
765   case CEED_EVAL_INTERP: {
766     CeedInt P1d, Q1d;
767     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
768     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
769     ierr = CeedCudaInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B);
770     CeedChk(ierr);
771     void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
772                           &d_u, &d_v
773                          };
774     if (dim == 1) {
775       CeedInt elemsPerBlock = 32;
776       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
777                                              ? 1 : 0 );
778       CeedInt sharedMem = elemsPerBlock*Q1d*sizeof(CeedScalar);
779       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, 1,
780                                         elemsPerBlock, sharedMem,
781                                         interpargs); CeedChk(ierr);
782     } else if (dim == 2) {
783       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
784       // elemsPerBlock must be at least 1
785       CeedInt elemsPerBlock = Q1d < 7 ? optElems[Q1d]/ncomp : 1;
786       if (elemsPerBlock == 0) elemsPerBlock = 1;
787       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
788                                              ? 1 : 0 );
789       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
790       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, Q1d,
791                                         ncomp*elemsPerBlock, sharedMem,
792                                         interpargs); CeedChk(ierr);
793     } else if (dim == 3) {
794       CeedInt elemsPerBlock = 1;
795       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
796                                              ? 1 : 0 );
797       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
798       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, Q1d,
799                                         ncomp*elemsPerBlock, sharedMem,
800                                         interpargs); CeedChk(ierr);
801     }
802   } break;
803   case CEED_EVAL_GRAD: {
804     CeedInt P1d, Q1d;
805     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
806     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
807     ierr = CeedCudaInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d,
808                                   Q1d, &data->c_B, &data->c_G);
809     CeedChk(ierr);
810     void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
811                         &data->c_G, &d_u, &d_v
812                        };
813     if (dim == 1) {
814       CeedInt elemsPerBlock = 32;
815       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
816                                              ? 1 : 0 );
817       CeedInt sharedMem = elemsPerBlock*Q1d*sizeof(CeedScalar);
818       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, 1,
819                                         elemsPerBlock, sharedMem, gradargs);
820       CeedChk(ierr);
821     } else if (dim == 2) {
822       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
823       // elemsPerBlock must be at least 1
824       CeedInt elemsPerBlock = Q1d < 7 ? optElems[Q1d]/ncomp : 1;
825       if (elemsPerBlock == 0) elemsPerBlock = 1;
826       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
827                                              ? 1 : 0 );
828       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
829       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, Q1d,
830                                         ncomp*elemsPerBlock, sharedMem,
831                                         gradargs); CeedChk(ierr);
832     } else if (dim == 3) {
833       CeedInt elemsPerBlock = 1;
834       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
835                                              ? 1 : 0 );
836       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
837       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, Q1d,
838                                         ncomp*elemsPerBlock, sharedMem,
839                                         gradargs); CeedChk(ierr);
840     }
841   } break;
842   case CEED_EVAL_WEIGHT: {
843     CeedInt Q1d;
844     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
845     void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
846     if (dim == 1) {
847       const CeedInt elemsPerBlock = 32/Q1d;
848       const CeedInt gridsize = nelem/elemsPerBlock + ( (
849                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
850       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d,
851                                   elemsPerBlock, 1, weightargs);
852       CeedChk(ierr);
853     } else if (dim == 2) {
854       const CeedInt optElems = 32/(Q1d*Q1d);
855       const CeedInt elemsPerBlock = optElems>0?optElems:1;
856       const CeedInt gridsize = nelem/elemsPerBlock + ( (
857                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
858       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d,
859                                   elemsPerBlock, weightargs);
860       CeedChk(ierr);
861     } else if (dim == 3) {
862       const CeedInt gridsize = nelem;
863       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d, Q1d,
864                                   weightargs);
865       CeedChk(ierr);
866     }
867   } break;
868   // LCOV_EXCL_START
869   // Evaluate the divergence to/from the quadrature points
870   case CEED_EVAL_DIV:
871     return CeedError(ceed, 1, "CEED_EVAL_DIV not supported");
872   // Evaluate the curl to/from the quadrature points
873   case CEED_EVAL_CURL:
874     return CeedError(ceed, 1, "CEED_EVAL_CURL not supported");
875   // Take no action, BasisApply should not have been called
876   case CEED_EVAL_NONE:
877     return CeedError(ceed, 1,
878                      "CEED_EVAL_NONE does not make sense in this context");
879     // LCOV_EXCL_STOP
880   }
881 
882   // Restore vectors
883   if (emode != CEED_EVAL_WEIGHT) {
884     ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChk(ierr);
885   }
886   ierr = CeedVectorRestoreArray(v, &d_v); CeedChk(ierr);
887   return 0;
888 }
889 
890 //------------------------------------------------------------------------------
891 // Destroy basis
892 //------------------------------------------------------------------------------
893 static int CeedBasisDestroy_Cuda_shared(CeedBasis basis) {
894   int ierr;
895   Ceed ceed;
896   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
897 
898   CeedBasis_Cuda_shared *data;
899   ierr = CeedBasisGetData(basis, (void *) &data); CeedChk(ierr);
900 
901   CeedChk_Cu(ceed, cuModuleUnload(data->module));
902 
903   ierr = cudaFree(data->d_qweight1d); CeedChk_Cu(ceed, ierr);
904   ierr = cudaFree(data->d_interp1d); CeedChk_Cu(ceed, ierr);
905   ierr = cudaFree(data->d_grad1d); CeedChk_Cu(ceed, ierr);
906   ierr = cudaFree(data->d_collograd1d); CeedChk_Cu(ceed, ierr);
907 
908   ierr = CeedFree(&data); CeedChk(ierr);
909 
910   return 0;
911 }
912 
913 //------------------------------------------------------------------------------
914 // Create tensor basis
915 //------------------------------------------------------------------------------
916 int CeedBasisCreateTensorH1_Cuda_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
917                                         const CeedScalar *interp1d,
918                                         const CeedScalar *grad1d,
919                                         const CeedScalar *qref1d,
920                                         const CeedScalar *qweight1d,
921                                         CeedBasis basis) {
922   int ierr;
923   Ceed ceed;
924   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
925   if (Q1d<P1d) {
926     return CeedError(ceed, 1, "Backend does not implement underintegrated basis.");
927   }
928   CeedBasis_Cuda_shared *data;
929   ierr = CeedCalloc(1, &data); CeedChk(ierr);
930 
931   // Copy basis data to GPU
932   const CeedInt qBytes = Q1d * sizeof(CeedScalar);
933   ierr = cudaMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Cu(ceed, ierr);
934   ierr = cudaMemcpy(data->d_qweight1d, qweight1d, qBytes,
935                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
936 
937   const CeedInt iBytes = qBytes * P1d;
938   ierr = cudaMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Cu(ceed, ierr);
939   ierr = cudaMemcpy(data->d_interp1d, interp1d, iBytes,
940                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
941 
942   ierr = cudaMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Cu(ceed, ierr);
943   ierr = cudaMemcpy(data->d_grad1d, grad1d, iBytes,
944                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
945 
946   // Compute collocated gradient and copy to GPU
947   data->d_collograd1d = NULL;
948   if (dim == 3 && Q1d >= P1d) {
949     CeedScalar *collograd1d;
950     ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChk(ierr);
951     ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChk(ierr);
952     ierr = cudaMalloc((void **)&data->d_collograd1d, qBytes * Q1d);
953     CeedChk_Cu(ceed, ierr);
954     ierr = cudaMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d,
955                       cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
956     ierr = CeedFree(&collograd1d); CeedChk(ierr);
957   }
958 
959   // Compile basis kernels
960   CeedInt ncomp;
961   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
962   ierr = CeedCompileCuda(ceed, kernelsShared, &data->module, 7,
963                          "Q1D", Q1d,
964                          "P1D", P1d,
965                          "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
966                              Q1d : P1d, dim),
967                          "BASIS_DIM", dim,
968                          "BASIS_NCOMP", ncomp,
969                          "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
970                          "BASIS_NQPT", CeedIntPow(Q1d, dim)
971                         ); CeedChk(ierr);
972   ierr = CeedGetKernelCuda(ceed, data->module, "interp", &data->interp);
973   CeedChk(ierr);
974   ierr = CeedGetKernelCuda(ceed, data->module, "grad", &data->grad);
975   CeedChk(ierr);
976   ierr = CeedGetKernelCuda(ceed, data->module, "weight", &data->weight);
977   CeedChk(ierr);
978 
979   ierr = CeedBasisSetData(basis, (void *)&data); CeedChk(ierr);
980 
981   // Register backend functions
982   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
983                                 CeedBasisApplyTensor_Cuda_shared);
984   CeedChk(ierr);
985   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
986                                 CeedBasisDestroy_Cuda_shared); CeedChk(ierr);
987   return 0;
988 }
989 //------------------------------------------------------------------------------
990