xref: /libCEED/backends/cuda-shared/ceed-cuda-shared-basis.c (revision 1d0137909cd290d677dbca28a6953e6505c9d054)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include "ceed-cuda-shared.h"
18 
19 //------------------------------------------------------------------------------
20 // Shared mem kernels
21 //------------------------------------------------------------------------------
22 // *INDENT-OFF*
23 static const char *kernelsShared = QUOTE(
24 
25 //------------------------------------------------------------------------------
26 // Sum input into output
27 //------------------------------------------------------------------------------
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  // Accumulate the P1D entries of r_U into r_V
  for (int j = 0; j < P1D; ++j)
    r_V[j] += r_U[j];
}
32 
33 //------------------------------------------------------------------------------
34 // 1D
35 //------------------------------------------------------------------------------
36 
37 //------------------------------------------------------------------------------
38 // Read DoFs
39 //------------------------------------------------------------------------------
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int tidz,const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *slice) {
  // Copy this element/component's P1D nodal values into the shared slice for
  // this z-layer, zero-padding up to Q1D when Q1D > P1D
  const int dst = tidz*T1D;
  const int src = elem*P1D + comp*P1D*nelem;
  for (int n = 0; n < P1D; ++n)
    slice[dst + n] = d_U[src + n];
  for (int n = P1D; n < Q1D; ++n)
    slice[dst + n] = 0.0;
}
49 
50 //------------------------------------------------------------------------------
51 // Write DoFs
52 //------------------------------------------------------------------------------
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  // Only the first P1D x-threads hold valid nodal values
  if (tidx < P1D) {
    const int ind = elem*P1D + comp*P1D*nelem + tidx;
    d_V[ind] = r_V;
  }
}
60 
61 //------------------------------------------------------------------------------
62 // Read quadrature point data
63 //------------------------------------------------------------------------------
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int tidz, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *slice) {
  // Copy the Q1D quadrature values for (elem, comp, dim) into the shared
  // slice for this z-layer, zero-padding up to P1D when P1D > Q1D
  const int dst = tidz*T1D;
  const int src = elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  for (int q = 0; q < Q1D; ++q)
    slice[dst + q] = d_U[src + q];
  for (int q = Q1D; q < P1D; ++q)
    slice[dst + q] = 0.0;
}
74 
75 //------------------------------------------------------------------------------
76 // Write quadrature point data
77 //------------------------------------------------------------------------------
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  // Only the first Q1D x-threads hold valid quadrature values
  if (tidx < Q1D) {
    const int ind = elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D + tidx;
    d_V[ind] = r_V;
  }
}
85 
86 //------------------------------------------------------------------------------
87 // 1D tensor contraction
88 //------------------------------------------------------------------------------
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  // V(tidx) = sum_p B(tidx, p) * slice(p); the U argument is not read here
  // (kept for a uniform contraction signature)
  CeedScalar sum = 0.0;
  for (int p = 0; p < P1D; ++p)
    sum += B[p + tidx*P1D] * slice[p + tidz*T1D]; // Contract x direction
  V = sum;
}
97 
98 //------------------------------------------------------------------------------
99 // 1D transpose tensor contraction
100 //------------------------------------------------------------------------------
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  // V(tidx) = sum_q B(q, tidx) * slice(q); B is read transposed. The U
  // argument is not read here (uniform contraction signature)
  CeedScalar sum = 0.0;
  for (int q = 0; q < Q1D; ++q)
    sum += B[tidx + q*P1D] * slice[q + tidz*T1D]; // Contract x direction
  V = sum;
}
108 
109 //------------------------------------------------------------------------------
110 // 1D interpolate to quadrature points
111 //------------------------------------------------------------------------------
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  // Apply the 1D interpolation basis (or its transpose); each z-thread of the
  // block handles one element, grid-striding over all elements.
  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  CeedScalar r_V;
  CeedScalar r_t; // dummy contraction argument; contractions read from slice

  const CeedInt stride = gridDim.x*blockDim.z;
  for (CeedInt elem = blockIdx.x*blockDim.z + tidz; elem < nelem;
       elem += stride) {
    for (int comp = 0; comp < BASIS_NCOMP; ++comp) {
      if (transpose) {
        // Quadrature values -> nodal values (B^T)
        readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      } else {
        // Nodal values -> quadrature values (B)
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      }
    }
  }
}
140 
141 //------------------------------------------------------------------------------
142 // 1D derivatives at quadrature points
143 //------------------------------------------------------------------------------
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // Apply the 1D gradient basis (or its transpose) using c_G; in 1D there is
  // only one derivative direction (dim == 0). c_B is unused here.
  CeedScalar r_U; // dummy contraction argument; contractions read from slice
  CeedScalar r_V;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;

  const CeedInt stride = gridDim.x*blockDim.z;
  for (CeedInt elem = blockIdx.x*blockDim.z + tidz; elem < nelem;
       elem += stride) {
    for (int comp = 0; comp < BASIS_NCOMP; ++comp) {
      if (transpose) {
        // Quadrature derivatives -> nodal values (G^T)
        readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      } else {
        // Nodal values -> quadrature derivatives (G)
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      }
    }
  }
}
174 
175 //------------------------------------------------------------------------------
176 // 1D Quadrature weights
177 //------------------------------------------------------------------------------
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  // Each x-thread owns one quadrature point; y-threads grid-stride over
  // elements. Assumes blockDim.x == Q1D — TODO confirm against launch code.
  const int q = threadIdx.x;
  const CeedScalar wq = qweight1d[q];
  const CeedInt stride = gridDim.x*blockDim.y;
  for (CeedInt e = blockIdx.x*blockDim.y + threadIdx.y; e < nelem; e += stride)
    w[e*Q1D + q] = wq;
}
188 
189 //------------------------------------------------------------------------------
190 // 2D
191 //------------------------------------------------------------------------------
192 
193 //------------------------------------------------------------------------------
194 // Read DoFs
195 //------------------------------------------------------------------------------
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar &U) {
  // Load one nodal value per thread; threads outside the P1D x P1D grid get 0
  if (tidx < P1D && tidy < P1D)
    U = d_U[comp*P1D*P1D*nelem + elem*P1D*P1D + tidy*P1D + tidx];
  else
    U = 0.0;
}
203 
204 //------------------------------------------------------------------------------
205 // Write DoFs
206 //------------------------------------------------------------------------------
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  // Only threads inside the P1D x P1D node grid hold valid data
  if (tidx < P1D && tidy < P1D) {
    const int ind = comp*P1D*P1D*nelem + elem*P1D*P1D + tidy*P1D + tidx;
    d_V[ind] = r_V;
  }
}
214 
215 //------------------------------------------------------------------------------
216 // Read quadrature point data
217 //------------------------------------------------------------------------------
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar &U ) {
  // Load one quadrature value for (elem, comp, dim); zero outside Q1D x Q1D
  if (tidx < Q1D && tidy < Q1D)
    U = d_U[dim*BASIS_NCOMP*nelem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
            elem*Q1D*Q1D + tidy*Q1D + tidx];
  else
    U = 0.0;
}
226 
227 //------------------------------------------------------------------------------
228 // Write quadrature point data
229 //------------------------------------------------------------------------------
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  // Only threads inside the Q1D x Q1D quadrature grid hold valid data
  if (tidx < Q1D && tidy < Q1D) {
    const int ind = dim*BASIS_NCOMP*nelem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                    elem*Q1D*Q1D + tidy*Q1D + tidx;
    d_V[ind] = r_V;
  }
}
238 
239 //------------------------------------------------------------------------------
240 // 2D tensor contraction x
241 //------------------------------------------------------------------------------
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  // Contract x: V(tidx, tidy) = sum_p B(tidx, p) * U(p, tidy).
  // Every thread stages its register value in shared memory first
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads(); // all writes to slice visible before any thread reads
  V = 0.0;
  // Only the first Q1D x-threads produce output
  if (tidx < Q1D)
    for (int i = 0; i < P1D; ++i)
      V += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
  // Barrier outside the guard: all threads must reach it before slice is reused
  __syncthreads();
}
254 
255 //------------------------------------------------------------------------------
256 // 2D tensor contraction y
257 //------------------------------------------------------------------------------
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  // Contract y: V(tidx, tidy) = sum_p B(tidy, p) * U(tidx, p).
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads(); // all writes to slice visible before any thread reads
  V = 0.0;
  // Only the first Q1D y-threads produce output
  if (tidy < Q1D)
    for (int i = 0; i < P1D; ++i)
      V += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
  // Barrier outside the guard: all threads must reach it before slice is reused
  __syncthreads();
}
270 
271 //------------------------------------------------------------------------------
272 // 2D transpose tensor contraction y
273 //------------------------------------------------------------------------------
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  // Transpose contract y: V(tidx, tidy) = sum_q B(q, tidy) * U(tidx, q);
  // B is read transposed (index tidy + i*P1D)
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads(); // all writes to slice visible before any thread reads
  V = 0.0;
  // Only the first P1D y-threads produce output (node grid)
  if (tidy < P1D)
    for (int i = 0; i < Q1D; ++i)
      V += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
  // Barrier outside the guard: all threads must reach it before slice is reused
  __syncthreads();
}
285 
286 //------------------------------------------------------------------------------
287 // 2D transpose tensor contraction x
288 //------------------------------------------------------------------------------
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  // Transpose contract x: V(tidx, tidy) = sum_q B(q, tidx) * U(q, tidy);
  // B is read transposed (index tidx + i*P1D)
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads(); // all writes to slice visible before any thread reads
  V = 0.0;
  // Only the first P1D x-threads produce output (node grid)
  if (tidx < P1D)
    for (int i = 0; i < Q1D; ++i)
      V += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
  // Barrier outside the guard: all threads must reach it before slice is reused
  __syncthreads();
}
300 
301 //------------------------------------------------------------------------------
302 // 2D interpolate to quadrature points
303 //------------------------------------------------------------------------------
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  // Apply the 2D interpolation basis (or its transpose) for all elements.
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  // z-threads are partitioned into (element slot, component) pairs;
  // assumes blockDim.z is a multiple of BASIS_NCOMP — TODO confirm launch code
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // Fix: removed a redundant inner re-declaration of comp that shadowed the
    // identical value computed above
    r_V = 0.0;
    r_t = 0.0;
    if (!transpose) {
      // Nodal values -> quadrature values: contract x then y with B
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
    } else {
      // Quadrature values -> nodal values: contract y then x with B^T
      readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
337 
338 //------------------------------------------------------------------------------
339 // 2D derivatives at quadrature points
340 //------------------------------------------------------------------------------
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V, CeedScalar *slice) {
  // Apply the 2D gradient basis (or its transpose): per derivative dim, c_G is
  // contracted in that direction and c_B in the other.
  CeedScalar r_U; // input register
  CeedScalar r_V; // output / accumulator register
  CeedScalar r_t; // scratch between the two contractions

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  // z-threads are partitioned into (element slot, component) pairs
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if (transpose) {
      // dim 0 contribution: B^T in y, then G^T in x
      readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      // dim 1 contribution: G^T in y, then B^T in x; sum into r_V
      readQuads2d(elem, tidx, tidy, comp, 1, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      r_V += r_U;
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    } else {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // d/dx: G in x, B in y
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      // d/dy: B in x, G in y
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 1, nelem, r_V, d_V);
    }
  }
}
383 
384 //------------------------------------------------------------------------------
385 // 2D quadrature weights
386 //------------------------------------------------------------------------------
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  // Tensor-product weight for this thread's (x, y) quadrature point;
  // z-threads grid-stride over elements
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy];
  const CeedInt stride = gridDim.x*blockDim.z;
  for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < nelem; e += stride)
    w[e*Q1D*Q1D + qy*Q1D + qx] = wq;
}
398 
399 //------------------------------------------------------------------------------
400 // 3D
401 //------------------------------------------------------------------------------
402 
403 //------------------------------------------------------------------------------
404 // Read DoFs
405 //------------------------------------------------------------------------------
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *r_U) {
  // Each thread loads its z-column of P1D nodal values into registers,
  // zero-padding up to Q1D; threads outside P1D x P1D load zeros
  const bool active = tidx < P1D && tidy < P1D;
  const int base = tidx + tidy*P1D + elem*P1D*P1D*P1D + comp*P1D*P1D*P1D*nelem;
  for (int k = 0; k < P1D; ++k)
    r_U[k] = active ? d_U[base + k*P1D*P1D] : 0.0;
  for (int k = P1D; k < Q1D; ++k)
    r_U[k] = 0.0;
}
417 
418 //------------------------------------------------------------------------------
419 // Write DoFs
420 //------------------------------------------------------------------------------
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V,
                                   CeedScalar *d_V) {
  // Store this thread's z-column of P1D nodal values; threads outside the
  // P1D x P1D node grid hold no data
  if (tidx < P1D && tidy < P1D) {
    const int base = tidx + tidy*P1D + elem*P1D*P1D*P1D + comp*P1D*P1D*P1D*nelem;
    for (int k = 0; k < P1D; ++k)
      d_V[base + k*P1D*P1D] = r_V[k];
  }
}
431 
432 //------------------------------------------------------------------------------
433 // Read quadrature point data
434 //------------------------------------------------------------------------------
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *r_U) {
  // Each thread loads its z-column of Q1D quadrature values for (comp, dim),
  // zero-padding up to P1D; threads outside Q1D x Q1D load zeros
  const bool active = tidx < Q1D && tidy < Q1D;
  const int base = tidx + tidy*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
                   dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
  for (int k = 0; k < Q1D; ++k)
    r_U[k] = active ? d_U[base + k*Q1D*Q1D] : 0.0;
  for (int k = Q1D; k < P1D; ++k)
    r_U[k] = 0.0;
}
446 
447 //------------------------------------------------------------------------------
448 // Write quadrature point data
449 //------------------------------------------------------------------------------
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar *r_V, CeedScalar *d_V) {
  // Store this thread's z-column of Q1D quadrature values for (comp, dim);
  // threads outside the Q1D x Q1D grid hold no data
  if (tidx < Q1D && tidy < Q1D) {
    const int base = tidx + tidy*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
                     dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
    for (int k = 0; k < Q1D; ++k)
      d_V[base + k*Q1D*Q1D] = r_V[k];
  }
}
460 
461 //------------------------------------------------------------------------------
462 // 3D tensor contract x
463 //------------------------------------------------------------------------------
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  // Contract x for each of the P1D z-layers held in registers:
  // V[k](tidx, tidy) = sum_p B(tidx, p) * U[k](p, tidy).
  // One shared-memory round trip (two barriers) per layer.
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads(); // writes to slice visible before reads
    V[k] = 0.0;
    // Output lives on the Q1D x P1D (x, y) subgrid
    if (tidx < Q1D && tidy < P1D)
      for (int i = 0; i < P1D; ++i)
        V[k] += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
    __syncthreads(); // protect slice before the next layer overwrites it
  }
}
479 
480 //------------------------------------------------------------------------------
481 // 3D tensor contract y
482 //------------------------------------------------------------------------------
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  // Contract y for each of the P1D z-layers held in registers:
  // V[k](tidx, tidy) = sum_p B(tidy, p) * U[k](tidx, p).
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads(); // writes to slice visible before reads
    V[k] = 0.0;
    // Output lives on the Q1D x Q1D (x, y) subgrid
    if (tidx < Q1D && tidy < Q1D)
      for (int i = 0; i < P1D; ++i)
        V[k] += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
    __syncthreads(); // protect slice before the next layer overwrites it
  }
}
498 
499 //------------------------------------------------------------------------------
500 // 3D tensor contract z
501 //------------------------------------------------------------------------------
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  // Contract z entirely in registers: V[k] = sum_p B(k, p) * U[p].
  // slice is unused (kept for a uniform contraction signature); no barriers
  // are required since no shared memory is touched.
  const bool active = tidx < Q1D && tidy < Q1D;
  for (int k = 0; k < Q1D; ++k) {
    CeedScalar sum = 0.0;
    if (active)
      for (int i = 0; i < P1D; ++i)
        sum += B[i + k*P1D] * U[i]; // Contract z direction
    V[k] = sum;
  }
  // Zero padding entries when P1D > Q1D
  for (int k = Q1D; k < P1D; ++k)
    V[k] = 0.0;
}
516 
517 //------------------------------------------------------------------------------
518 // 3D transpose tensor contract z
519 //------------------------------------------------------------------------------
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  // Transpose contract z entirely in registers: V[k] = sum_q B(q, k) * U[q];
  // B is read transposed. slice is unused (uniform contraction signature).
  const bool active = tidx < Q1D && tidy < Q1D;
  for (int k = 0; k < P1D; ++k) {
    CeedScalar sum = 0.0;
    if (active)
      for (int i = 0; i < Q1D; ++i)
        sum += B[k + i*P1D] * U[i]; // Contract z direction
    V[k] = sum;
  }
  // Zero padding entries when Q1D > P1D
  for (int k = P1D; k < Q1D; ++k)
    V[k] = 0.0;
}
534 
535 //------------------------------------------------------------------------------
536 // 3D transpose tensor contract y
537 //------------------------------------------------------------------------------
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  // Transpose contract y for each z-layer: V[k](tidx, tidy) =
  // sum_q B(q, tidy) * U[k](tidx, q); B is read transposed.
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads(); // writes to slice visible before reads
    V[k] = 0.0;
    // Output lives on the Q1D x P1D (x, y) subgrid
    if (tidx < Q1D && tidy < P1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
    __syncthreads(); // protect slice before the next layer overwrites it
  }
}
553 
554 //------------------------------------------------------------------------------
555 // 3D transpose tensor contract x
556 //------------------------------------------------------------------------------
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  // Transpose contract x for each z-layer: V[k](tidx, tidy) =
  // sum_q B(q, tidx) * U[k](q, tidy); B is read transposed.
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads(); // writes to slice visible before reads
    V[k] = 0.0;
    // Output lives on the P1D x P1D node subgrid
    if (tidx < P1D && tidy < P1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
    __syncthreads(); // protect slice before the next layer overwrites it
  }
}
572 
573 //------------------------------------------------------------------------------
574 // 3D interpolate to quadrature points
575 //------------------------------------------------------------------------------
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  // Apply the 3D interpolation basis (or its transpose) as three successive
  // 1D contractions, ping-ponging between two register buffers.
  CeedScalar r_a[T1D];
  CeedScalar r_b[T1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  // z-threads are partitioned into (element slot, component) pairs
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // Clear both register buffers before reuse for this element
    for (int j = 0; j < T1D; ++j) {
      r_a[j] = 0.0;
      r_b[j] = 0.0;
    }
    if (transpose) {
      // Quadrature values -> nodal values: contract z, y, then x with B^T
      readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_a);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_a, c_B, r_b);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_b, c_B, r_a);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_a, c_B, r_b);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_b, d_V);
    } else {
      // Nodal values -> quadrature values: contract x, y, then z with B
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_a);
      ContractX3d(slice, tidx, tidy, tidz, r_a, c_B, r_b);
      ContractY3d(slice, tidx, tidy, tidz, r_b, c_B, r_a);
      ContractZ3d(slice, tidx, tidy, tidz, r_a, c_B, r_b);
      writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_b, d_V);
    }
  }
}
612 
613 //------------------------------------------------------------------------------
614 // 3D derivatives at quadrature points
615 //------------------------------------------------------------------------------
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // Apply the 3D gradient basis (forward) or its transpose (accumulating the
  // three directional contributions). Per derivative dim, c_G is contracted in
  // that direction and c_B in the other two.
  // Use P1D for one of these
  CeedScalar r_U[T1D];
  CeedScalar r_V[T1D];
  CeedScalar r_t[T1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  // z-threads are partitioned into (element slot, component) pairs
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // Clear all register buffers before reuse for this element
    for (int i = 0; i < T1D; ++i) {
      r_U[i] = 0.0;
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      // Nodal values -> quadrature derivatives; r_U keeps the input so it can
      // feed all three dims
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // d/dx: G in x, B in y and z
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // d/dy: G in y, B in x and z
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_G, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 1;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // d/dz: G in z, B in x and y
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 2;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      // Quadrature derivatives -> nodal values; r_U doubles as scratch, the
      // running sum lives in r_V
      dim = 0;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      dim = 1;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_G, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t); // accumulate dim 1 contribution
      dim = 2;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t); // accumulate dim 2 contribution
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
680 
681 //------------------------------------------------------------------------------
682 // 3D quadrature weights
683 //------------------------------------------------------------------------------
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  // One thread per 3D quadrature point: (i, j, k) index the tensor-product
  // point within an element (block is expected to be Q1D x Q1D x Q1D)
  const int i = threadIdx.x;
  const int j = threadIdx.y;
  const int k = threadIdx.z;
  // 3D weight is the tensor product of the three 1D quadrature weights
  const CeedScalar weight = qweight1d[i]*qweight1d[j]*qweight1d[k];
  // Grid-stride loop over elements; the same weight pattern is written
  // into each element's block of the output array
  for (int e = blockIdx.x; e < nelem; e += gridDim.x) {
    // x fastest, then y, then z within the element
    const int ind = e*Q1D*Q1D*Q1D + i + j*Q1D + k*Q1D*Q1D;
    w[ind] = weight;
  }
}
695 
696 
697 //------------------------------------------------------------------------------
698 // Basis kernels
699 //------------------------------------------------------------------------------
700 
701 //------------------------------------------------------------------------------
702 // Interp kernel by dim
703 //------------------------------------------------------------------------------
extern "C" __global__ void interp(const CeedInt nelem, const int transpose,
                                  const CeedScalar *c_B,
                                  const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  // Dynamic shared-memory scratch area; its size is supplied by the host
  // at kernel launch
  extern __shared__ double slice[];
  // BASIS_DIM is a compile-time macro (set when this source is compiled),
  // so only one of these branches survives in the generated kernel
  if (BASIS_DIM == 1) {
    interp1d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM == 2) {
    interp2d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM == 3) {
    interp3d(nelem, transpose, c_B, d_U, d_V, slice);
  }
}
717 
718 //------------------------------------------------------------------------------
719 // Grad kernel by dim
720 //------------------------------------------------------------------------------
extern "C" __global__ void grad(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *c_G,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V) {
  // Dynamic shared-memory scratch area; its size is supplied by the host
  // at kernel launch
  extern __shared__ double slice[];
  // BASIS_DIM is a compile-time macro, so only one branch survives
  if (BASIS_DIM == 1) {
    grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM == 2) {
    grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM == 3) {
    grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  }
}
734 
735 //------------------------------------------------------------------------------
736 // Weight kernels by dim
737 //------------------------------------------------------------------------------
extern "C" __global__ void weight(const CeedInt nelem,
                                  const CeedScalar *__restrict__ qweight1d,
                                  CeedScalar *__restrict__ v) {
  // Dispatch on spatial dimension; BASIS_DIM is a compile-time macro,
  // so only one branch survives in the generated kernel
  if (BASIS_DIM == 1) {
    weight1d(nelem, qweight1d, v);
  } else if (BASIS_DIM == 2) {
    weight2d(nelem, qweight1d, v);
  } else if (BASIS_DIM == 3) {
    weight3d(nelem, qweight1d, v);
  }
}
749 
750 );
751 // *INDENT-ON*
752 
753 //------------------------------------------------------------------------------
754 // Device initalization
755 //------------------------------------------------------------------------------
// Forward declarations; definitions live in a sibling CUDA source file.
// Make the interpolation matrix d_B available to kernels via c_B
// (presumably device constant memory — confirm against the definition).
int CeedCudaInitInterp(CeedScalar *d_B, CeedInt P1d, CeedInt Q1d,
                       CeedScalar **c_B);
// Same as above for both the interpolation (d_B) and gradient (d_G) matrices
int CeedCudaInitInterpGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P1d,
                           CeedInt Q1d, CeedScalar **c_B_ptr,
                           CeedScalar **c_G_ptr);
761 
762 //------------------------------------------------------------------------------
763 // Apply basis
764 //------------------------------------------------------------------------------
765 int CeedBasisApplyTensor_Cuda_shared(CeedBasis basis, const CeedInt nelem,
766                                      CeedTransposeMode tmode,
767                                      CeedEvalMode emode, CeedVector u,
768                                      CeedVector v) {
769   int ierr;
770   Ceed ceed;
771   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
772   Ceed_Cuda_shared *ceed_Cuda;
773   CeedGetData(ceed, &ceed_Cuda); CeedChk(ierr);
774   CeedBasis_Cuda_shared *data;
775   CeedBasisGetData(basis, &data); CeedChk(ierr);
776   const CeedInt transpose = tmode == CEED_TRANSPOSE;
777   CeedInt dim, ncomp;
778   ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
779   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
780 
781   // Read vectors
782   const CeedScalar *d_u;
783   CeedScalar *d_v;
784   if (emode != CEED_EVAL_WEIGHT) {
785     ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChk(ierr);
786   }
787   ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChk(ierr);
788 
789   // Clear v for transpose mode
790   if (tmode == CEED_TRANSPOSE) {
791     CeedInt length;
792     ierr = CeedVectorGetLength(v, &length); CeedChk(ierr);
793     ierr = cudaMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChk(ierr);
794   }
795 
796   // Apply basis operation
797   switch (emode) {
798   case CEED_EVAL_INTERP: {
799     CeedInt P1d, Q1d;
800     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
801     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
802     CeedInt thread1d = CeedIntMax(Q1d, P1d);
803     ierr = CeedCudaInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B);
804     CeedChk(ierr);
805     void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
806                           &d_u, &d_v
807                          };
808     if (dim == 1) {
809       CeedInt elemsPerBlock = 32;
810       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
811                                              ? 1 : 0 );
812       CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
813       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, thread1d, 1,
814                                         elemsPerBlock, sharedMem,
815                                         interpargs); CeedChk(ierr);
816     } else if (dim == 2) {
817       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
818       // elemsPerBlock must be at least 1
819       CeedInt elemsPerBlock = CeedIntMax(thread1d<7?optElems[thread1d]/ncomp:1, 1);
820       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
821                                              ? 1 : 0 );
822       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
823       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, thread1d, thread1d,
824                                         ncomp*elemsPerBlock, sharedMem,
825                                         interpargs); CeedChk(ierr);
826     } else if (dim == 3) {
827       CeedInt elemsPerBlock = 1;
828       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
829                                              ? 1 : 0 );
830       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
831       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, thread1d, thread1d,
832                                         ncomp*elemsPerBlock, sharedMem,
833                                         interpargs); CeedChk(ierr);
834     }
835   } break;
836   case CEED_EVAL_GRAD: {
837     CeedInt P1d, Q1d;
838     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
839     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
840     CeedInt thread1d = CeedIntMax(Q1d, P1d);
841     ierr = CeedCudaInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d,
842                                   Q1d, &data->c_B, &data->c_G);
843     CeedChk(ierr);
844     void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
845                         &data->c_G, &d_u, &d_v
846                        };
847     if (dim == 1) {
848       CeedInt elemsPerBlock = 32;
849       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
850                                              ? 1 : 0 );
851       CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
852       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, thread1d, 1,
853                                         elemsPerBlock, sharedMem, gradargs);
854       CeedChk(ierr);
855     } else if (dim == 2) {
856       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
857       // elemsPerBlock must be at least 1
858       CeedInt elemsPerBlock = CeedIntMax(thread1d<7?optElems[thread1d]/ncomp:1, 1);
859       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
860                                              ? 1 : 0 );
861       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
862       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, thread1d, thread1d,
863                                         ncomp*elemsPerBlock, sharedMem,
864                                         gradargs); CeedChk(ierr);
865     } else if (dim == 3) {
866       CeedInt elemsPerBlock = 1;
867       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
868                                              ? 1 : 0 );
869       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
870       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, thread1d, thread1d,
871                                         ncomp*elemsPerBlock, sharedMem,
872                                         gradargs); CeedChk(ierr);
873     }
874   } break;
875   case CEED_EVAL_WEIGHT: {
876     CeedInt Q1d;
877     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
878     void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
879     if (dim == 1) {
880       const CeedInt elemsPerBlock = 32/Q1d;
881       const CeedInt gridsize = nelem/elemsPerBlock + ( (
882                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
883       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d,
884                                   elemsPerBlock, 1, weightargs);
885       CeedChk(ierr);
886     } else if (dim == 2) {
887       const CeedInt optElems = 32/(Q1d*Q1d);
888       const CeedInt elemsPerBlock = optElems>0?optElems:1;
889       const CeedInt gridsize = nelem/elemsPerBlock + ( (
890                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
891       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d,
892                                   elemsPerBlock, weightargs);
893       CeedChk(ierr);
894     } else if (dim == 3) {
895       const CeedInt gridsize = nelem;
896       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d, Q1d,
897                                   weightargs);
898       CeedChk(ierr);
899     }
900   } break;
901   // LCOV_EXCL_START
902   // Evaluate the divergence to/from the quadrature points
903   case CEED_EVAL_DIV:
904     return CeedError(ceed, 1, "CEED_EVAL_DIV not supported");
905   // Evaluate the curl to/from the quadrature points
906   case CEED_EVAL_CURL:
907     return CeedError(ceed, 1, "CEED_EVAL_CURL not supported");
908   // Take no action, BasisApply should not have been called
909   case CEED_EVAL_NONE:
910     return CeedError(ceed, 1,
911                      "CEED_EVAL_NONE does not make sense in this context");
912     // LCOV_EXCL_STOP
913   }
914 
915   // Restore vectors
916   if (emode != CEED_EVAL_WEIGHT) {
917     ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChk(ierr);
918   }
919   ierr = CeedVectorRestoreArray(v, &d_v); CeedChk(ierr);
920   return 0;
921 }
922 
923 //------------------------------------------------------------------------------
924 // Destroy basis
925 //------------------------------------------------------------------------------
926 static int CeedBasisDestroy_Cuda_shared(CeedBasis basis) {
927   int ierr;
928   Ceed ceed;
929   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
930 
931   CeedBasis_Cuda_shared *data;
932   ierr = CeedBasisGetData(basis, &data); CeedChk(ierr);
933 
934   CeedChk_Cu(ceed, cuModuleUnload(data->module));
935 
936   ierr = cudaFree(data->d_qweight1d); CeedChk_Cu(ceed, ierr);
937   ierr = cudaFree(data->d_interp1d); CeedChk_Cu(ceed, ierr);
938   ierr = cudaFree(data->d_grad1d); CeedChk_Cu(ceed, ierr);
939   ierr = cudaFree(data->d_collograd1d); CeedChk_Cu(ceed, ierr);
940 
941   ierr = CeedFree(&data); CeedChk(ierr);
942 
943   return 0;
944 }
945 
946 //------------------------------------------------------------------------------
947 // Create tensor basis
948 //------------------------------------------------------------------------------
int CeedBasisCreateTensorH1_Cuda_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
                                        const CeedScalar *interp1d,
                                        const CeedScalar *grad1d,
                                        const CeedScalar *qref1d,
                                        const CeedScalar *qweight1d,
                                        CeedBasis basis) {
  int ierr;
  Ceed ceed;
  ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
  CeedBasis_Cuda_shared *data;
  ierr = CeedCalloc(1, &data); CeedChk(ierr);

  // Copy basis data to GPU
  // Note: qref1d is not stored on the device by this backend
  // 1D quadrature weights: Q1d scalars
  const CeedInt qBytes = Q1d * sizeof(CeedScalar);
  ierr = cudaMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Cu(ceed, ierr);
  ierr = cudaMemcpy(data->d_qweight1d, qweight1d, qBytes,
                    cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);

  // 1D interpolation and gradient matrices: Q1d*P1d scalars each
  const CeedInt iBytes = qBytes * P1d;
  ierr = cudaMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Cu(ceed, ierr);
  ierr = cudaMemcpy(data->d_interp1d, interp1d, iBytes,
                    cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);

  ierr = cudaMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Cu(ceed, ierr);
  ierr = cudaMemcpy(data->d_grad1d, grad1d, iBytes,
                    cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);

  // Compute collocated gradient and copy to GPU
  // Only built for 3D bases with Q1d >= P1d; otherwise d_collograd1d stays
  // NULL (safe to cudaFree in the destroy path)
  data->d_collograd1d = NULL;
  if (dim == 3 && Q1d >= P1d) {
    CeedScalar *collograd1d;
    // Q1d x Q1d gradient at the quadrature points themselves
    ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChk(ierr);
    ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChk(ierr);
    ierr = cudaMalloc((void **)&data->d_collograd1d, qBytes * Q1d);
    CeedChk_Cu(ceed, ierr);
    ierr = cudaMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d,
                      cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
    ierr = CeedFree(&collograd1d); CeedChk(ierr);
  }

  // Compile basis kernels
  // The 8 name/value pairs below become macros (Q1D, P1D, T1D, ...) in the
  // kernelsShared source string when it is compiled for this basis
  CeedInt ncomp;
  ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
  ierr = CeedCompileCuda(ceed, kernelsShared, &data->module, 8,
                         "Q1D", Q1d,
                         "P1D", P1d,
                         "T1D", CeedIntMax(Q1d, P1d),
                         "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
                             Q1d : P1d, dim),
                         "BASIS_DIM", dim,
                         "BASIS_NCOMP", ncomp,
                         "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
                         "BASIS_NQPT", CeedIntPow(Q1d, dim)
                        ); CeedChk(ierr);
  // Look up the three entry-point kernels in the compiled module
  ierr = CeedGetKernelCuda(ceed, data->module, "interp", &data->interp);
  CeedChk(ierr);
  ierr = CeedGetKernelCuda(ceed, data->module, "grad", &data->grad);
  CeedChk(ierr);
  ierr = CeedGetKernelCuda(ceed, data->module, "weight", &data->weight);
  CeedChk(ierr);

  ierr = CeedBasisSetData(basis, data); CeedChk(ierr);

  // Register backend functions
  ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
                                CeedBasisApplyTensor_Cuda_shared);
  CeedChk(ierr);
  ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
                                CeedBasisDestroy_Cuda_shared); CeedChk(ierr);
  return 0;
}
1020 //------------------------------------------------------------------------------
1021