xref: /libCEED/backends/cuda-shared/ceed-cuda-shared-basis.c (revision dc8efd83546faf0200bf0bfcfb1678fae1874cc5)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <ceed.h>
18 #include <ceed-backend.h>
19 #include <cuda.h>
20 #include <cuda_runtime.h>
21 #include <stddef.h>
22 #include "ceed-cuda-shared.h"
23 #include "../cuda/ceed-cuda.h"
24 
25 //------------------------------------------------------------------------------
26 // Shared mem kernels
27 //------------------------------------------------------------------------------
28 // *INDENT-OFF*
29 static const char *kernelsShared = QUOTE(
30 
//------------------------------------------------------------------------------
// Sum input into output
//------------------------------------------------------------------------------
// Accumulate the P1D register values of r_U into r_V, entry by entry.
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  for (int k = P1D; k-- > 0; )
    r_V[k] += r_U[k];
}
38 
39 //------------------------------------------------------------------------------
40 // 1D
41 //------------------------------------------------------------------------------
42 
//------------------------------------------------------------------------------
// Read DoFs
//------------------------------------------------------------------------------
// Load one component of one element's P1D nodal values into the shared-memory
// row for this tidz layer, zero-padding entries [P1D, Q1D) so a subsequent
// contraction reads well-defined data when Q1D > P1D.
//   d_U layout: node + elem*P1D + comp*P1D*nelem
//   slice:      shared scratch, T1D entries per tidz layer
// NOTE(review): tidx/tidy are unused; every thread in a tidz layer stores the
// same values to the same slice entries — presumably a benign redundant write,
// but confirm against the launch configuration.
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int tidz,const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *slice) {
  for (int i = 0; i < P1D; i++)
    slice[i + tidz*T1D] = d_U[i + elem*P1D + comp*P1D*nelem];
  for (int i = P1D; i < Q1D; i++)
    slice[i + tidz*T1D] = 0.0;
}
55 
//------------------------------------------------------------------------------
// Write DoFs
//------------------------------------------------------------------------------
// Store this thread's result r_V to its node of the given element/component.
// Threads with tidx >= P1D carry no node and write nothing.
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D)
    return;
  const int ind = tidx + elem*P1D + comp*P1D*nelem;
  d_V[ind] = r_V;
}
66 
//------------------------------------------------------------------------------
// Read quadrature point data
//------------------------------------------------------------------------------
// Load one component (and derivative slot 'dim') of one element's Q1D
// quadrature values into the shared-memory row for this tidz layer,
// zero-padding entries [Q1D, P1D) for the case P1D > Q1D.
//   d_U layout: point + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D
// NOTE(review): tidx/tidy are unused; all threads in a tidz layer store the
// same values — presumably a benign redundant write, confirm with the launch.
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int tidz, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *slice) {
  for (int i = 0; i < Q1D; i++)
    slice[i + tidz*T1D] = d_U[i + elem*Q1D + comp*Q1D*nelem +
                            dim*BASIS_NCOMP*nelem*Q1D];
  for (int i = Q1D; i < P1D; i++)
    slice[i + tidz*T1D] = 0.0;
}
80 
//------------------------------------------------------------------------------
// Write quadrature point data
//------------------------------------------------------------------------------
// Store this thread's result r_V to its quadrature point of the given
// element/component/derivative slot. Threads with tidx >= Q1D write nothing.
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx >= Q1D)
    return;
  const int ind = tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  d_V[ind] = r_V;
}
91 
//------------------------------------------------------------------------------
// 1D tensor contraction
//------------------------------------------------------------------------------
// V = sum_i B[i + tidx*P1D] * slice[i + tidz*T1D]: contract the staged shared
// values against B for this thread's output point tidx.
// NOTE(review): the U parameter is unused — the input is read from slice,
// which the caller fills beforehand (see readDofs1d); U exists only to keep a
// uniform contraction signature.
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  V = 0.0;
  for (int i = 0; i < P1D; ++i)
    V += B[i + tidx*P1D] * slice[i + tidz*T1D]; // Contract x direction
}
103 
//------------------------------------------------------------------------------
// 1D transpose tensor contraction
//------------------------------------------------------------------------------
// V = sum_i B[tidx + i*P1D] * slice[i + tidz*T1D]: transpose contraction (B
// indexed with tidx in the fast dimension), mapping Q1D quadrature values in
// slice back to this thread's node tidx.
// NOTE(review): U is unused, as in ContractX1d — input comes from slice.
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  V = 0.0;
  for (int i = 0; i < Q1D; ++i)
    V += B[tidx + i*P1D] * slice[i + tidz*T1D]; // Contract x direction
}
114 
//------------------------------------------------------------------------------
// 1D interpolate to quadrature points
//------------------------------------------------------------------------------
// Apply the 1D interpolation basis c_B (or its transpose) to every component
// of every element. Elements are distributed with a grid-stride loop over
// blockIdx.x and the tidz layers of the block; each tidz layer owns one
// element per iteration.
//   transpose == 0: DoFs (d_U) -> quadrature values (d_V)
//   transpose != 0: quadrature values (d_U) -> DoFs (d_V)
// r_t is passed to the contraction uninitialized; ContractX1d /
// ContractTransposeX1d ignore their U argument and read from slice instead.
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;


  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for (int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) {
        // Stage DoFs in shared memory, contract, write quadrature values
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      } else {
        // Stage quadrature values, transpose-contract, write DoFs
        readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
146 
//------------------------------------------------------------------------------
// 1D derivatives at quadrature points
//------------------------------------------------------------------------------
// Apply the 1D derivative basis c_G (or its transpose) to every component of
// every element; same element distribution as interp1d (grid-stride over
// blockIdx.x and the tidz layers).
//   transpose == 0: DoFs (d_U) -> derivative values at quadrature points (d_V)
//   transpose != 0: derivative values (d_U) -> DoFs (d_V)
// In 1D there is a single derivative direction, so dim is always 0. r_U is
// passed to the contraction uninitialized; the 1D contractions ignore their U
// argument and read from slice. c_B is unused here — kept so grad1d/grad2d/
// grad3d share a calling convention.
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  int dim;

  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) {
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        dim = 0;
        writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      } else {
        dim = 0;
        readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
180 
//------------------------------------------------------------------------------
// 1D Quadrature weights
//------------------------------------------------------------------------------
// Broadcast the 1D quadrature weights to every element: thread tidx owns
// quadrature point tid of each element it visits (grid-stride over blockIdx.x
// and the blockDim.y layers).
// NOTE(review): qweight1d[tid] is read unguarded — assumes blockDim.x == Q1D;
// confirm against the launch configuration.
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int tid = threadIdx.x;
  const CeedScalar weight = qweight1d[tid];
  for (CeedInt elem = blockIdx.x*blockDim.y + threadIdx.y; elem < nelem;
       elem += gridDim.x*blockDim.y) {
    const int ind = elem*Q1D + tid;
    w[ind] = weight;
  }
}
194 
195 //------------------------------------------------------------------------------
196 // 2D
197 //------------------------------------------------------------------------------
198 
//------------------------------------------------------------------------------
// Read DoFs
//------------------------------------------------------------------------------
// Load this thread's nodal value (tidx, tidy) of the given element/component
// into the register U; threads outside the P1D x P1D node grid get 0.
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar &U) {
  U = 0.0;
  if (tidx<P1D && tidy<P1D) {
    const int ind = tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem;
    U = d_U[ind];
  }
}
209 
//------------------------------------------------------------------------------
// Write DoFs
//------------------------------------------------------------------------------
// Store this thread's result r_V to node (tidx, tidy) of the given
// element/component; threads outside the P1D x P1D node grid write nothing.
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D)
    return;
  const int ind = tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem;
  d_V[ind] = r_V;
}
220 
//------------------------------------------------------------------------------
// Read quadrature point data
//------------------------------------------------------------------------------
// Load this thread's quadrature value (tidx, tidy) of the given element/
// component/derivative slot into U; threads outside Q1D x Q1D get 0.
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar &U ) {
  U = 0.0;
  if (tidx<Q1D && tidy<Q1D) {
    const int ind = tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                    dim*BASIS_NCOMP*nelem*Q1D*Q1D;
    U = d_U[ind];
  }
}
232 
//------------------------------------------------------------------------------
// Write quadrature point data
//------------------------------------------------------------------------------
// Store this thread's result r_V to quadrature point (tidx, tidy) of the
// given element/component/derivative slot; threads outside Q1D x Q1D write
// nothing.
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx >= Q1D || tidy >= Q1D)
    return;
  const int ind = tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                  dim*BASIS_NCOMP*nelem*Q1D*Q1D;
  d_V[ind] = r_V;
}
244 
//------------------------------------------------------------------------------
// 2D tensor contraction x
//------------------------------------------------------------------------------
// Stage each thread's value U in shared memory, then contract the x index
// against B: V = sum_i B[i + tidx*P1D] * slice(i, tidy). Threads with
// tidx >= Q1D produce V = 0 but still reach both barriers, so this must be
// called uniformly by every thread in the block.
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidx < Q1D)
    for (int i = 0; i < P1D; ++i)
      V += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
  __syncthreads(); // slice may be overwritten by the next contraction
}
260 
//------------------------------------------------------------------------------
// 2D tensor contraction y
//------------------------------------------------------------------------------
// Stage each thread's value U in shared memory, then contract the y index
// against B: V = sum_i B[i + tidy*P1D] * slice(tidx, i). Threads with
// tidy >= Q1D produce V = 0 but still reach both barriers — call uniformly.
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidy < Q1D)
    for (int i = 0; i < P1D; ++i)
      V += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
  __syncthreads(); // slice may be overwritten by the next contraction
}
276 
//------------------------------------------------------------------------------
// 2D transpose tensor contraction y
//------------------------------------------------------------------------------
// Transpose of ContractY2d: V = sum_i B[tidy + i*P1D] * slice(tidx, i),
// summing over the Q1D quadrature index. Threads with tidy >= P1D produce
// V = 0 but still reach both barriers — call uniformly.
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidy < P1D)
    for (int i = 0; i < Q1D; ++i)
      V += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
  __syncthreads(); // slice may be overwritten by the next contraction
}
291 
//------------------------------------------------------------------------------
// 2D transpose tensor contraction x
//------------------------------------------------------------------------------
// Transpose of ContractX2d: V = sum_i B[tidx + i*P1D] * slice(i, tidy),
// summing over the Q1D quadrature index. Threads with tidx >= P1D produce
// V = 0 but still reach both barriers — call uniformly.
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidx < P1D)
    for (int i = 0; i < Q1D; ++i)
      V += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
  __syncthreads(); // slice may be overwritten by the next contraction
}
306 
//------------------------------------------------------------------------------
// 2D interpolate to quadrature points
//------------------------------------------------------------------------------
// Apply the 2D tensor-product interpolation (c_B in x then y), or its
// transpose, for every element/component. The block's z dimension packs
// BASIS_NCOMP components per element: tidz encodes (blockElem, comp), and the
// grid-stride loop advances by gridDim.x*elemsPerBlock elements.
//   transpose == 0: DoFs (d_U) -> quadrature values (d_V)
//   transpose != 0: quadrature values (d_U) -> DoFs (d_V)
// Fix: the original redeclared `comp` inside the loop, shadowing the
// identical function-scope computation (tidz%BASIS_NCOMP); the dead duplicate
// is removed.
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    r_V = 0.0;
    r_t = 0.0;
    if (!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
    } else {
      readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
343 
//------------------------------------------------------------------------------
// 2D derivatives at quadrature points
//------------------------------------------------------------------------------
// Apply the 2D tensor-product gradient for every element/component, using the
// same tidz = (blockElem, comp) packing and grid-stride element loop as
// interp2d.
//   transpose == 0: DoFs -> (d/dx, d/dy) at quadrature points; the x
//     derivative uses c_G in x / c_B in y (dim 0), the y derivative uses
//     c_B in x / c_G in y (dim 1).
//   transpose != 0: the two derivative slots of d_U are transpose-contracted
//     and summed back into a single DoF vector.
// The contraction/write order is load-bearing: each Contract*2d reuses the
// same shared slice, and r_U must survive the dim-0 pass for the dim-1 pass.
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V, CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if (!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 1;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      dim = 0;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 1;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      r_V += r_U; // accumulate the two derivative contributions
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
389 
//------------------------------------------------------------------------------
// 2D quadrature weights
//------------------------------------------------------------------------------
// Write the tensor-product weight qweight1d[qx]*qweight1d[qy] to quadrature
// point (qx, qy) of every element; elements are distributed grid-stride over
// blockIdx.x and the blockDim.z layers.
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy];
  const CeedInt stride = gridDim.x*blockDim.z;
  for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < nelem; e += stride) {
    w[e*Q1D*Q1D + qx + qy*Q1D] = wq;
  }
}
404 
405 //------------------------------------------------------------------------------
406 // 3D
407 //------------------------------------------------------------------------------
408 
//------------------------------------------------------------------------------
// Read DoFs
//------------------------------------------------------------------------------
// Load the z column of nodal values at (tidx, tidy) for the given element/
// component into the register array r_U (one entry per z node), zero-padding
// entries [P1D, Q1D) and threads outside the P1D x P1D grid.
//   d_U layout: tidx + tidy*P1D + z*P1D*P1D + elem*P1D^3 + comp*P1D^3*nelem
//   r_U must hold at least max(P1D, Q1D) entries.
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *r_U) {
  for (int i = 0; i < P1D; i++)
    r_U[i] = (tidx < P1D && tidy < P1D) ?
              d_U[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D +
                  comp*P1D*P1D*P1D*nelem] : 0.0;
  for (int i = P1D; i < Q1D; i++)
    r_U[i] = 0.0;
}
423 
//------------------------------------------------------------------------------
// Write DoFs
//------------------------------------------------------------------------------
// Store the register column r_V (one entry per z node) to the nodes at
// (tidx, tidy) of the given element/component; threads outside the
// P1D x P1D grid write nothing.
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D)
    return;
  const int base = tidx + tidy*P1D + elem*P1D*P1D*P1D + comp*P1D*P1D*P1D*nelem;
  for (int k = 0; k < P1D; k++)
    d_V[base + k*P1D*P1D] = r_V[k];
}
437 
//------------------------------------------------------------------------------
// Read quadrature point data
//------------------------------------------------------------------------------
// Load the z column of quadrature values at (tidx, tidy) for the given
// element/component/derivative slot into the register array r_U, zero-padding
// entries [Q1D, P1D) and threads outside the Q1D x Q1D grid.
//   d_U layout: tidx + tidy*Q1D + z*Q1D*Q1D + elem*Q1D^3 + comp*Q1D^3*nelem
//               + dim*BASIS_NCOMP*nelem*Q1D^3
//   r_U must hold at least max(P1D, Q1D) entries.
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *r_U) {
  for (int i = 0; i < Q1D; i++)
    r_U[i] = (tidx < Q1D && tidy < Q1D) ?
              d_U[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D +
              comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] : 0.0;
  for (int i = Q1D; i < P1D; i++)
    r_U[i] = 0.0;
}
452 
//------------------------------------------------------------------------------
// Write quadrature point data
//------------------------------------------------------------------------------
// Store the register column r_V (one entry per z point) to the quadrature
// points at (tidx, tidy) of the given element/component/derivative slot;
// threads outside the Q1D x Q1D grid write nothing.
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar *r_V, CeedScalar *d_V) {
  if (tidx >= Q1D || tidy >= Q1D)
    return;
  const int base = tidx + tidy*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
                   dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
  for (int k = 0; k < Q1D; k++)
    d_V[base + k*Q1D*Q1D] = r_V[k];
}
466 
//------------------------------------------------------------------------------
// 3D tensor contract x
//------------------------------------------------------------------------------
// For each z layer k, stage U[k] in shared memory and contract the x index
// against B: V[k] = sum_i B[i + tidx*P1D] * slice(i, tidy). Inactive threads
// (tidx >= Q1D or tidy >= P1D) produce V[k] = 0 but still reach both barriers
// each iteration, so this must be called uniformly by the whole block.
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < Q1D && tidy < P1D)
      for (int i = 0; i < P1D; ++i)
        V[k] += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
    __syncthreads(); // slice is reused for the next z layer
  }
}
485 
//------------------------------------------------------------------------------
// 3D tensor contract y
//------------------------------------------------------------------------------
// For each z layer k, stage U[k] in shared memory and contract the y index
// against B: V[k] = sum_i B[i + tidy*P1D] * slice(tidx, i). Inactive threads
// produce V[k] = 0 but still reach both barriers each iteration — call
// uniformly by the whole block.
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < Q1D && tidy < Q1D)
      for (int i = 0; i < P1D; ++i)
        V[k] += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
    __syncthreads(); // slice is reused for the next z layer
  }
}
504 
//------------------------------------------------------------------------------
// 3D tensor contract z
//------------------------------------------------------------------------------
// Contract the z index against B entirely in registers: each thread holds its
// full z column in U, so V[k] = sum_i B[i + k*P1D] * U[i] with no shared
// memory or barriers. Entries [Q1D, P1D) are zeroed for the case P1D > Q1D.
// slice and tidz are unused — kept so all Contract*3d share one signature.
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < Q1D; ++k) {
    V[k] = 0.0;
    if (tidx < Q1D && tidy < Q1D)
      for (int i = 0; i < P1D; ++i)
        V[k] += B[i + k*P1D] * U[i]; // Contract z direction
  }
  for (int k = Q1D; k < P1D; ++k)
    V[k] = 0.0;
}
522 
//------------------------------------------------------------------------------
// 3D transpose tensor contract z
//------------------------------------------------------------------------------
// Transpose of ContractZ3d, in registers: V[k] = sum_i B[k + i*P1D] * U[i],
// summing over the Q1D quadrature index. Entries [P1D, Q1D) are zeroed for
// the case Q1D > P1D. slice and tidz are unused — uniform signature only.
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    V[k] = 0.0;
    if (tidx < Q1D && tidy < Q1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[k + i*P1D] * U[i]; // Contract z direction
  }
  for (int k = P1D; k < Q1D; ++k)
    V[k] = 0.0;
}
540 
//------------------------------------------------------------------------------
// 3D transpose tensor contract y
//------------------------------------------------------------------------------
// Transpose of ContractY3d: per z layer k, stage U[k] in shared memory, then
// V[k] = sum_i B[tidy + i*P1D] * slice(tidx, i) over the Q1D quadrature
// index. Inactive threads still reach both barriers each iteration — call
// uniformly by the whole block.
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < Q1D && tidy < P1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
    __syncthreads(); // slice is reused for the next z layer
  }
}
559 
//------------------------------------------------------------------------------
// 3D transpose tensor contract x
//------------------------------------------------------------------------------
// Transpose of ContractX3d: per z layer k, stage U[k] in shared memory, then
// V[k] = sum_i B[tidx + i*P1D] * slice(i, tidy) over the Q1D quadrature
// index. Inactive threads still reach both barriers each iteration — call
// uniformly by the whole block.
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < P1D && tidy < P1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
    __syncthreads(); // slice is reused for the next z layer
  }
}
578 
//------------------------------------------------------------------------------
// 3D interpolate to quadrature points
//------------------------------------------------------------------------------
// Apply the 3D tensor-product interpolation (c_B in x, y, then z), or its
// transpose, for every element/component. Each thread owns one (tidx, tidy)
// column and keeps its z values in the register arrays r_V/r_t (T1D entries,
// zero-initialized each element). tidz packs (blockElem, comp) as in
// interp2d; elements advance grid-stride by gridDim.x*elemsPerBlock.
//   transpose == 0: DoFs (d_U) -> quadrature values (d_V)
//   transpose != 0: quadrature values (d_U) -> DoFs (d_V)
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V[T1D];
  CeedScalar r_t[T1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // Clear register columns before each element
    for (int i = 0; i < T1D; ++i) {
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
    } else {
      readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
    }
  }
}
618 
//------------------------------------------------------------------------------
// 3D derivatives at quadrature points
//------------------------------------------------------------------------------
// Apply the 3D tensor-product gradient for every element/component; same
// thread/element decomposition as interp3d.
//   transpose == 0: DoFs -> (d/dx, d/dy, d/dz) at quadrature points; each
//     derivative slot uses c_G in its own direction and c_B in the others.
//   transpose != 0: the three derivative slots of d_U are transpose-
//     contracted and summed (via add) into one DoF vector.
// The pass order is load-bearing: the shared slice is reused by every
// contraction, and r_U must survive each forward pass for the next one.
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // Use P1D for one of these
  // (pre-existing note: one array could be sized P1D instead of T1D)
  CeedScalar r_U[T1D];
  CeedScalar r_V[T1D];
  CeedScalar r_t[T1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // Clear register columns before each element
    for (int i = 0; i < T1D; ++i) {
      r_U[i] = 0.0;
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V); // d/dx pass
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V); // d/dy pass
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_G, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 1;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V); // d/dz pass
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 2;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      dim = 0;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      dim = 1;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_G, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t); // accumulate dim-1 contribution
      dim = 2;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t); // accumulate dim-2 contribution
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
686 
687 //------------------------------------------------------------------------------
688 // 3D quadrature weights
689 //------------------------------------------------------------------------------
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  // One thread per quadrature point of the Q1D x Q1D x Q1D tensor block
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const int qz = threadIdx.z;
  // The 3D weight is the tensor product of three 1D quadrature weights
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy]*qweight1d[qz];
  // Stride over elements by the grid size; every element stores the same
  // weight at its local (qx, qy, qz) entry
  for (int e = blockIdx.x; e < nelem; e += gridDim.x) {
    const int ind = e*Q1D*Q1D*Q1D + qz*Q1D*Q1D + qy*Q1D + qx;
    w[ind] = wq;
  }
}
701 
702 
703 //------------------------------------------------------------------------------
704 // Basis kernels
705 //------------------------------------------------------------------------------
706 
707 //------------------------------------------------------------------------------
708 // Interp kernel by dim
709 //------------------------------------------------------------------------------
extern "C" __global__ void interp(const CeedInt nelem, const int transpose,
                                  const CeedScalar *c_B,
                                  const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  // Dynamic shared memory: contraction slice buffer sized by the host launcher
  extern __shared__ double slice[];
  // BASIS_DIM is a compile-time constant, so only one case survives
  switch (BASIS_DIM) {
  case 1: interp1d(nelem, transpose, c_B, d_U, d_V, slice); break;
  case 2: interp2d(nelem, transpose, c_B, d_U, d_V, slice); break;
  case 3: interp3d(nelem, transpose, c_B, d_U, d_V, slice); break;
  }
}
723 
724 //------------------------------------------------------------------------------
725 // Grad kernel by dim
726 //------------------------------------------------------------------------------
extern "C" __global__ void grad(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *c_G,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V) {
  // Dynamic shared memory: contraction slice buffer sized by the host launcher
  extern __shared__ double slice[];
  // BASIS_DIM is a compile-time constant, so only one case survives
  switch (BASIS_DIM) {
  case 1: grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice); break;
  case 2: grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice); break;
  case 3: grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice); break;
  }
}
740 
741 //------------------------------------------------------------------------------
742 // Weight kernels by dim
743 //------------------------------------------------------------------------------
extern "C" __global__ void weight(const CeedInt nelem,
                                  const CeedScalar *__restrict__ qweight1d,
                                  CeedScalar *__restrict__ v) {
  // BASIS_DIM is a compile-time constant, so only one case survives
  switch (BASIS_DIM) {
  case 1: weight1d(nelem, qweight1d, v); break;
  case 2: weight2d(nelem, qweight1d, v); break;
  case 3: weight3d(nelem, qweight1d, v); break;
  }
}
755 
756 );
757 // *INDENT-ON*
758 
759 //------------------------------------------------------------------------------
760 // Device initalization
761 //------------------------------------------------------------------------------
762 int CeedCudaInitInterp(CeedScalar *d_B, CeedInt P1d, CeedInt Q1d,
763                        CeedScalar **c_B);
764 int CeedCudaInitInterpGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P1d,
765                            CeedInt Q1d, CeedScalar **c_B_ptr,
766                            CeedScalar **c_G_ptr);
767 
768 //------------------------------------------------------------------------------
769 // Apply basis
770 //------------------------------------------------------------------------------
771 int CeedBasisApplyTensor_Cuda_shared(CeedBasis basis, const CeedInt nelem,
772                                      CeedTransposeMode tmode,
773                                      CeedEvalMode emode, CeedVector u,
774                                      CeedVector v) {
775   int ierr;
776   Ceed ceed;
777   ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
778   Ceed_Cuda_shared *ceed_Cuda;
779   CeedGetData(ceed, &ceed_Cuda); CeedChkBackend(ierr);
780   CeedBasis_Cuda_shared *data;
781   CeedBasisGetData(basis, &data); CeedChkBackend(ierr);
782   const CeedInt transpose = tmode == CEED_TRANSPOSE;
783   CeedInt dim, ncomp;
784   ierr = CeedBasisGetDimension(basis, &dim); CeedChkBackend(ierr);
785   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChkBackend(ierr);
786 
787   // Read vectors
788   const CeedScalar *d_u;
789   CeedScalar *d_v;
790   if (emode != CEED_EVAL_WEIGHT) {
791     ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChkBackend(ierr);
792   }
793   ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChkBackend(ierr);
794 
795   // Clear v for transpose mode
796   if (tmode == CEED_TRANSPOSE) {
797     CeedInt length;
798     ierr = CeedVectorGetLength(v, &length); CeedChkBackend(ierr);
799     ierr = cudaMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChkBackend(ierr);
800   }
801 
802   // Apply basis operation
803   switch (emode) {
804   case CEED_EVAL_INTERP: {
805     CeedInt P1d, Q1d;
806     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChkBackend(ierr);
807     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr);
808     CeedInt thread1d = CeedIntMax(Q1d, P1d);
809     ierr = CeedCudaInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B);
810     CeedChkBackend(ierr);
811     void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
812                           &d_u, &d_v
813                          };
814     if (dim == 1) {
815       CeedInt elemsPerBlock = 32;
816       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
817                                              ? 1 : 0 );
818       CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
819       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, thread1d, 1,
820                                         elemsPerBlock, sharedMem,
821                                         interpargs); CeedChkBackend(ierr);
822     } else if (dim == 2) {
823       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
824       // elemsPerBlock must be at least 1
825       CeedInt elemsPerBlock = CeedIntMax(thread1d<7?optElems[thread1d]/ncomp:1, 1);
826       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
827                                              ? 1 : 0 );
828       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
829       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, thread1d, thread1d,
830                                         ncomp*elemsPerBlock, sharedMem,
831                                         interpargs); CeedChkBackend(ierr);
832     } else if (dim == 3) {
833       CeedInt elemsPerBlock = 1;
834       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
835                                              ? 1 : 0 );
836       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
837       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, thread1d, thread1d,
838                                         ncomp*elemsPerBlock, sharedMem,
839                                         interpargs); CeedChkBackend(ierr);
840     }
841   } break;
842   case CEED_EVAL_GRAD: {
843     CeedInt P1d, Q1d;
844     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChkBackend(ierr);
845     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr);
846     CeedInt thread1d = CeedIntMax(Q1d, P1d);
847     ierr = CeedCudaInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d,
848                                   Q1d, &data->c_B, &data->c_G);
849     CeedChkBackend(ierr);
850     void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
851                         &data->c_G, &d_u, &d_v
852                        };
853     if (dim == 1) {
854       CeedInt elemsPerBlock = 32;
855       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
856                                              ? 1 : 0 );
857       CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
858       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, thread1d, 1,
859                                         elemsPerBlock, sharedMem, gradargs);
860       CeedChkBackend(ierr);
861     } else if (dim == 2) {
862       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
863       // elemsPerBlock must be at least 1
864       CeedInt elemsPerBlock = CeedIntMax(thread1d<7?optElems[thread1d]/ncomp:1, 1);
865       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
866                                              ? 1 : 0 );
867       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
868       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, thread1d, thread1d,
869                                         ncomp*elemsPerBlock, sharedMem,
870                                         gradargs); CeedChkBackend(ierr);
871     } else if (dim == 3) {
872       CeedInt elemsPerBlock = 1;
873       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
874                                              ? 1 : 0 );
875       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
876       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, thread1d, thread1d,
877                                         ncomp*elemsPerBlock, sharedMem,
878                                         gradargs); CeedChkBackend(ierr);
879     }
880   } break;
881   case CEED_EVAL_WEIGHT: {
882     CeedInt Q1d;
883     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr);
884     void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
885     if (dim == 1) {
886       const CeedInt elemsPerBlock = 32/Q1d;
887       const CeedInt gridsize = nelem/elemsPerBlock + ( (
888                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
889       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d,
890                                   elemsPerBlock, 1, weightargs);
891       CeedChkBackend(ierr);
892     } else if (dim == 2) {
893       const CeedInt optElems = 32/(Q1d*Q1d);
894       const CeedInt elemsPerBlock = optElems>0?optElems:1;
895       const CeedInt gridsize = nelem/elemsPerBlock + ( (
896                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
897       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d,
898                                   elemsPerBlock, weightargs);
899       CeedChkBackend(ierr);
900     } else if (dim == 3) {
901       const CeedInt gridsize = nelem;
902       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d, Q1d,
903                                   weightargs);
904       CeedChkBackend(ierr);
905     }
906   } break;
907   // LCOV_EXCL_START
908   // Evaluate the divergence to/from the quadrature points
909   case CEED_EVAL_DIV:
910     return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported");
911   // Evaluate the curl to/from the quadrature points
912   case CEED_EVAL_CURL:
913     return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported");
914   // Take no action, BasisApply should not have been called
915   case CEED_EVAL_NONE:
916     return CeedError(ceed, CEED_ERROR_BACKEND,
917                      "CEED_EVAL_NONE does not make sense in this context");
918     // LCOV_EXCL_STOP
919   }
920 
921   // Restore vectors
922   if (emode != CEED_EVAL_WEIGHT) {
923     ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChkBackend(ierr);
924   }
925   ierr = CeedVectorRestoreArray(v, &d_v); CeedChkBackend(ierr);
926   return CEED_ERROR_SUCCESS;
927 }
928 
929 //------------------------------------------------------------------------------
930 // Destroy basis
931 //------------------------------------------------------------------------------
932 static int CeedBasisDestroy_Cuda_shared(CeedBasis basis) {
933   int ierr;
934   Ceed ceed;
935   ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
936 
937   CeedBasis_Cuda_shared *data;
938   ierr = CeedBasisGetData(basis, &data); CeedChkBackend(ierr);
939 
940   CeedChk_Cu(ceed, cuModuleUnload(data->module));
941 
942   ierr = cudaFree(data->d_qweight1d); CeedChk_Cu(ceed, ierr);
943   ierr = cudaFree(data->d_interp1d); CeedChk_Cu(ceed, ierr);
944   ierr = cudaFree(data->d_grad1d); CeedChk_Cu(ceed, ierr);
945   ierr = cudaFree(data->d_collograd1d); CeedChk_Cu(ceed, ierr);
946 
947   ierr = CeedFree(&data); CeedChkBackend(ierr);
948 
949   return CEED_ERROR_SUCCESS;
950 }
951 
952 //------------------------------------------------------------------------------
953 // Create tensor basis
954 //------------------------------------------------------------------------------
int CeedBasisCreateTensorH1_Cuda_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
                                        const CeedScalar *interp1d,
                                        const CeedScalar *grad1d,
                                        const CeedScalar *qref1d,
                                        const CeedScalar *qweight1d,
                                        CeedBasis basis) {
  // Create a tensor-product H1 basis for the cuda-shared backend:
  // copies the 1D basis data (interp/grad/weights) to the device,
  // optionally builds a collocated gradient, JIT-compiles the shared-memory
  // kernels, and registers Apply/Destroy on the basis object.
  // Note: qref1d is accepted for interface uniformity but is not copied to
  // the device here.
  int ierr;
  Ceed ceed;
  ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
  CeedBasis_Cuda_shared *data;
  // Zero-initialized so pointer fields start NULL
  ierr = CeedCalloc(1, &data); CeedChkBackend(ierr);

  // Copy basis data to GPU
  const CeedInt qBytes = Q1d * sizeof(CeedScalar);
  ierr = cudaMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Cu(ceed, ierr);
  ierr = cudaMemcpy(data->d_qweight1d, qweight1d, qBytes,
                    cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);

  // interp1d and grad1d are both Q1d x P1d matrices
  const CeedInt iBytes = qBytes * P1d;
  ierr = cudaMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Cu(ceed, ierr);
  ierr = cudaMemcpy(data->d_interp1d, interp1d, iBytes,
                    cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);

  ierr = cudaMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Cu(ceed, ierr);
  ierr = cudaMemcpy(data->d_grad1d, grad1d, iBytes,
                    cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);

  // Compute collocated gradient and copy to GPU
  // Only built for 3D with Q1d >= P1d; otherwise d_collograd1d stays NULL
  data->d_collograd1d = NULL;
  if (dim == 3 && Q1d >= P1d) {
    CeedScalar *collograd1d;
    // Collocated gradient is a Q1d x Q1d matrix
    ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChkBackend(ierr);
    ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChkBackend(ierr);
    ierr = cudaMalloc((void **)&data->d_collograd1d, qBytes * Q1d);
    CeedChk_Cu(ceed, ierr);
    ierr = cudaMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d,
                      cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
    ierr = CeedFree(&collograd1d); CeedChkBackend(ierr);
  }

  // Compile basis kernels
  // The 8 name/value pairs become #defines in the JIT-compiled source;
  // T1D sizes thread blocks and registers to cover both P1d and Q1d
  CeedInt ncomp;
  ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChkBackend(ierr);
  ierr = CeedCompileCuda(ceed, kernelsShared, &data->module, 8,
                         "Q1D", Q1d,
                         "P1D", P1d,
                         "T1D", CeedIntMax(Q1d, P1d),
                         "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
                             Q1d : P1d, dim),
                         "BASIS_DIM", dim,
                         "BASIS_NCOMP", ncomp,
                         "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
                         "BASIS_NQPT", CeedIntPow(Q1d, dim)
                        ); CeedChkBackend(ierr);
  // Look up kernel handles in the compiled module
  ierr = CeedGetKernelCuda(ceed, data->module, "interp", &data->interp);
  CeedChkBackend(ierr);
  ierr = CeedGetKernelCuda(ceed, data->module, "grad", &data->grad);
  CeedChkBackend(ierr);
  ierr = CeedGetKernelCuda(ceed, data->module, "weight", &data->weight);
  CeedChkBackend(ierr);

  ierr = CeedBasisSetData(basis, data); CeedChkBackend(ierr);

  // Register backend functions
  ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
                                CeedBasisApplyTensor_Cuda_shared);
  CeedChkBackend(ierr);
  ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
                                CeedBasisDestroy_Cuda_shared); CeedChkBackend(ierr);
  return CEED_ERROR_SUCCESS;
}
1026 //------------------------------------------------------------------------------
1027