xref: /libCEED/backends/cuda-shared/ceed-cuda-shared-basis.c (revision e6a04bf5cb06fd03ee882d1ea55149c09771081b)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <ceed-backend.h>
18 #include <ceed.h>
19 #include "ceed-cuda-shared.h"
20 #include "../cuda/ceed-cuda.h"
21 
22 //*********************
23 // shared mem kernels
24 static const char *kernelsShared = QUOTE(
25 
// Element-wise accumulation of one register line: r_V[j] += r_U[j] for the
// Q1D entries each thread keeps in registers.
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  for (int j = 0; j < Q1D; j++) {
    r_V[j] = r_V[j] + r_U[j];
  }
}
30 
31 //////////
32 //  1D  //
33 //////////
34 
// Load one element's 1D nodal values for component `comp` into the shared
// `slice` (one row of Q1D entries per z-thread), zero-padding entries past
// P1D so contractions can loop over a full Q1D row.
// NOTE(review): tidx/tidy are unused - every thread sharing a tidz writes the
// same values redundantly (benign, since all writers store identical data).
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int tidz,const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *slice) {
  // dof layout: node fastest, then component, then element
  for (int i = 0; i < P1D; i++)
    slice[i+tidz*Q1D] = d_U[i + comp*P1D + elem*BASIS_NCOMP*P1D];
  // zero padding from P1D up to Q1D
  for (int i = P1D; i < Q1D; i++)
    slice[i+tidz*Q1D] = 0.0;
}
44 
// Store one nodal value r_V for (elem, comp); only the first P1D x-threads
// carry valid node data, the rest are masked out.
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx<P1D) {
    d_V[tidx + comp*P1D + elem*BASIS_NCOMP*P1D] = r_V;
  }
}
53 
// Load one element's Q1D quadrature values for (comp, dim) into the shared
// `slice` row belonging to tidz.
// Quadrature layout: point fastest, then element, then component, then dim.
// NOTE(review): tidx/tidy are unused - writes are redundant across threads
// sharing a tidz, but all store identical values.
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int tidz, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *slice) {
  for (int i = 0; i < Q1D; i++)
    slice[i+tidz*Q1D] = d_U[i + elem*Q1D + comp*Q1D*nelem +
                            dim*BASIS_NCOMP*nelem*Q1D];
}
62 
// Store one quadrature value r_V for (elem, comp, dim); each x-thread writes
// its own quadrature point (tidx spans 0..Q1D-1).
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  d_V[tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D] = r_V;
}
69 
// 1D contraction: V = sum_i B[tidx,i] * slice[i], with B stored row-major as
// Q1D x P1D (B[i + tidx*P1D]). Input comes from the shared slice row for
// tidz; the U argument is unused here.
// No __syncthreads is needed: each thread reads only the slice row it (and
// its tidz peers, storing identical values) wrote.
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  V = 0.0;
  for (int i = 0; i < P1D; ++i) {
    V += B[i + tidx*P1D] * slice[i+tidz*Q1D];//contract x direction
  }
}
79 
// Transposed 1D contraction: V = sum_i B[i,tidx] * slice[i] (B read
// column-wise via B[tidx + i*P1D]). The U argument is unused.
// NOTE(review): there is no tidx<P1D guard here - for tidx >= P1D the result
// is discarded by writeDofs1d, but B is still indexed past column P1D-1;
// confirm B's allocation covers this or that such reads are harmless.
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  V = 0.0;
  for (int i = 0; i < Q1D; ++i) {
    V += B[tidx + i*P1D] * slice[i+tidz*Q1D];//contract x direction
  }
}
88 
// 1D interpolation (and transpose) for a batch of elements.
// Thread layout: threadIdx.x spans quadrature points/nodes, threadIdx.z picks
// the element within the block; the loop grid-strides over elements.
// NOTE(review): r_t is passed uninitialized as the unused U argument of the
// contraction helpers - harmless but confusing.
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;


  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp=0; comp<BASIS_NCOMP; comp++) {
      if(!transpose) {
        // nodes -> quadrature points: apply B
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      } else {
        // quadrature points -> nodes: apply B^T
        readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
117 
// 1D gradient (and transpose) for a batch of elements: identical structure to
// interp1d but contracts with the gradient matrix c_G instead of c_B.
// NOTE(review): r_U is passed uninitialized as the unused U argument of the
// contraction helpers.
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  int dim;

  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp=0; comp<BASIS_NCOMP; comp++) {
      if(!transpose) {
        // nodes -> quadrature derivatives: apply G
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        dim = 0;
        writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      } else {
        // quadrature derivatives -> nodes: apply G^T
        dim = 0;
        readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
148 //////////
149 //  2D  //
150 //////////
151 
// Load one 2D nodal value for (elem, comp) into register U; threads outside
// the P1D x P1D node grid get 0 so they contribute nothing to contractions.
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar &U) {
  U = (tidx<P1D
       && tidy<P1D) ? d_U[tidx + tidy*P1D + comp*P1D*P1D +
                          elem*BASIS_NCOMP*P1D*P1D ] :
      0.0;
}
161 
// Store one 2D nodal value r_V for (elem, comp); only threads inside the
// P1D x P1D node grid hold valid results.
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx<P1D && tidy<P1D) {
    d_V[tidx + tidy*P1D + comp*P1D*P1D + elem*BASIS_NCOMP*P1D*P1D ] = r_V;
  }
}
170 
// Load one 2D quadrature value for (elem, comp, dim) into register U.
// Layout: point fastest (x then y), then element, then component, then dim.
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar &U ) {
  U = d_U[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
               dim*BASIS_NCOMP*nelem*Q1D*Q1D];
}
178 
// Store one 2D quadrature value r_V for (elem, comp, dim); every thread in
// the Q1D x Q1D tile writes its own point.
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  d_V[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
           dim*BASIS_NCOMP*nelem*Q1D*Q1D ] = r_V;
}
186 
// 2D contraction along x: stage each thread's U into the shared tile, then
// V(tidx,tidy) = sum_i B[tidx,i] * slice(i,tidy). B is Q1D x P1D row-major.
// Both __syncthreads calls are reached by all threads in the block (no
// divergent control flow here): one after the tile is filled, one before the
// caller may overwrite the tile.
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  for (int i = 0; i < P1D; ++i) {
    V += B[i + tidx*P1D] * slice[i + tidy*Q1D + tidz*Q1D*Q1D];//contract x direction
  }
  __syncthreads();
}
199 
// 2D contraction along y: stage U into the shared tile, then
// V(tidx,tidy) = sum_i B[tidy,i] * slice(tidx,i). Barriers bracket the tile
// use and are reached by all threads.
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  for (int i = 0; i < P1D; ++i) {
    V += B[i + tidy*P1D] * slice[tidx + i*Q1D + tidz*Q1D*Q1D];//contract y direction
  }
  __syncthreads();
}
212 
// Transposed 2D contraction along y: V(tidx,tidy) = sum_i B[i,tidy] *
// slice(tidx,i). Only tidy < P1D threads produce a value (others keep V=0);
// the guard wraps only the accumulation, so both barriers are still reached
// by every thread in the block.
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidy<P1D) {
    for (int i = 0; i < Q1D; ++i) {
      V += B[tidy + i*P1D] * slice[tidx + i*Q1D + tidz*Q1D*Q1D];//contract y direction
    }
  }
  __syncthreads();
}
226 
// Transposed 2D contraction along x: V(tidx,tidy) = sum_i B[i,tidx] *
// slice(i,tidy). Only tidx < P1D threads produce a value; barriers are
// outside the guard and reached by all threads.
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidx<P1D) {
    for (int i = 0; i < Q1D; ++i) {
      V += B[tidx + i*P1D] * slice[i + tidy*Q1D + tidz*Q1D*Q1D];//contract x direction
    }
  }
  __syncthreads();
}
240 
// 2D tensor-product interpolation (and transpose) for a batch of elements.
// Each thread owns one (x,y) point; threadIdx.z encodes both the component
// and the element within the block: tidz = blockElem*BASIS_NCOMP + comp.
// Fix: removed the inner `const int comp` declaration that pointlessly
// shadowed the identical outer one.
// NOTE(review): if nelem is not a multiple of elemsPerBlock, z-slices of the
// last block can disagree on loop trip count while the contractions call
// __syncthreads - confirm the launch configuration rules this out.
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    r_V = 0.0;
    r_t = 0.0;
    if(!transpose) {
      // nodes -> quadrature points: B in x, then B in y
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
    } else {
      // quadrature points -> nodes: B^T in y, then B^T in x
      readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
274 
// 2D tensor-product gradient (and transpose) for a batch of elements.
// Forward: produces d/dx (G in x, B in y) into dim 0 and d/dy (B in x, G in
// y) into dim 1. Transpose: accumulates G^T/B^T contributions from both dims
// into a single nodal value. Thread layout matches interp2d
// (tidz = blockElem*BASIS_NCOMP + comp).
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V, CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if(!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // d/dx: differentiate in x, interpolate in y
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // d/dy: interpolate in x, differentiate in y
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 1;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      // transpose-apply of the dim-0 (d/dx) block
      dim = 0;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      // transpose-apply of the dim-1 (d/dy) block, summed into r_V
      dim = 1;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      r_V+=r_U;
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
317 //////////
318 //  3D  //
319 //////////
320 
// Load a z-line of 3D nodal values for (elem, comp) into the register array
// r_U (each thread owns all z-levels of its (x,y) point). Threads outside
// the P1D x P1D grid get zeros, and entries past P1D are zero-padded.
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *r_U) {
  for (int i = 0; i < P1D; i++)
    r_U[i] = (tidx<P1D
              && tidy<P1D) ? d_U[tidx + tidy*P1D + i*P1D*P1D + comp*P1D*P1D*P1D +
                                      elem*BASIS_NCOMP*P1D*P1D*P1D ] : 0.0;
  // zero padding from P1D up to Q1D
  for (int i = P1D; i < Q1D; i++)
    r_U[i] = 0.0;
}
332 
// Load a z-line of 3D quadrature values for (elem, comp, dim) into r_U.
// Layout: point fastest (x, y, z), then element, then component, then dim.
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *r_U) {
  for (int i = 0; i < Q1D; i++)
    r_U[i] = d_U[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D +
                 comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D];
}
341 
// Store a z-line of 3D nodal values r_V for (elem, comp); only threads inside
// the P1D x P1D grid hold valid results.
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V,
                                   CeedScalar *d_V) {
  if (tidx<P1D && tidy<P1D) {
    for (int i = 0; i < P1D; i++)
      d_V[tidx + tidy*P1D + i*P1D*P1D + comp*P1D*P1D*P1D +
          elem*BASIS_NCOMP*P1D*P1D*P1D ] = r_V[i];
  }
}
352 
// Store a z-line of 3D quadrature values r_V for (elem, comp, dim); every
// thread in the Q1D x Q1D tile writes its own column of Q1D points.
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar *r_V, CeedScalar *d_V) {
  for (int i = 0; i < Q1D; i++)
    d_V[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
        dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D ] = r_V[i];
}
361 
// 3D contraction along x, one z-level at a time: for each k, stage U[k] into
// the shared tile, then V[k](tidx,tidy) = sum_i B[tidx,i] * slice(i,tidy).
// Barriers bracket each level's tile reuse and are reached by all threads.
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    for (int i = 0; i < P1D; ++i) {
      V[k] += B[i + tidx*P1D] * slice[i + tidy*Q1D +
                                      tidz*Q1D*Q1D];//contract x direction
    }
    __syncthreads();
  }
}
377 
// 3D contraction along y, one z-level at a time: for each k,
// V[k](tidx,tidy) = sum_i B[tidy,i] * slice(tidx,i).
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    for (int i = 0; i < P1D; ++i) {
      V[k] += B[i + tidy*P1D] * slice[tidx + i*Q1D +
                                      tidz*Q1D*Q1D];//contract y direction
    }
    __syncthreads();
  }
}
393 
// 3D contraction along z: V[k] = sum_i B[k,i] * U[i]. Each thread owns the
// whole z-line in registers, so no shared memory or synchronization is
// needed; the slice/tidx/tidy/tidz parameters are unused (kept for a
// uniform call signature).
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < Q1D; ++k) {
    V[k] = 0.0;
    for (int i = 0; i < P1D; ++i) {
      V[k] += B[i + k*P1D] * U[i];//contract z direction
    }
  }
}
405 
// Transposed 3D contraction along z: V[k] = sum_i B[i,k] * U[i] for k < P1D,
// zero otherwise. Register-only like ContractZ3d; slice/tid params unused.
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < Q1D; ++k) {
    V[k] = 0.0;
    if (k<P1D) {
      for (int i = 0; i < Q1D; ++i) {
        V[k] += B[k + i*P1D] * U[i];//contract z direction
      }
    }
  }
}
418 
// Transposed 3D contraction along y, one z-level at a time:
// V[k](tidx,tidy) = sum_i B[i,tidy] * slice(tidx,i) for tidy < P1D, zero
// otherwise. The guard wraps only the accumulation, so both barriers per
// level are reached by all threads.
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidy<P1D) {
      for (int i = 0; i < Q1D; ++i) {
        V[k] += B[tidy + i*P1D] * slice[tidx + i*Q1D +
                                        tidz*Q1D*Q1D];//contract y direction
      }
    }
    __syncthreads();
  }
}
435 
// Transposed 3D contraction along x, one z-level at a time:
// V[k](tidx,tidy) = sum_i B[i,tidx] * slice(i,tidy) for tidx < P1D, zero
// otherwise. Barriers are outside the guard and reached by all threads.
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx<P1D) {
      for (int i = 0; i < Q1D; ++i) {
        V[k] += B[tidx + i*P1D] * slice[i + tidy*Q1D +
                                        tidz*Q1D*Q1D];//contract x direction
      }
    }
    __syncthreads();
  }
}
452 
// 3D tensor-product interpolation (and transpose) for a batch of elements.
// Each thread owns the full z-line of its (x,y) point in registers; x/y
// contractions go through the shared tile, the z contraction is
// register-only. tidz encodes component and element within the block.
// NOTE(review): as with interp2d, unequal loop trip counts across z-slices of
// a block would make the __syncthreads inside the contractions divergent -
// confirm the launch configuration rules this out.
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V[Q1D];
  CeedScalar r_t[Q1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    for (int i = 0; i < Q1D; ++i) {
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if(!transpose) {
      // nodes -> quadrature points: B in x, y, then z
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
    } else {
      // quadrature points -> nodes: B^T in z, y, then x
      readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
    }
  }
}
489 
// 3D tensor-product gradient (and transpose) for a batch of elements.
// Forward: for each dim, differentiate (G) along that axis and interpolate
// (B) along the other two, writing one quadrature field per dim. Transpose:
// apply the adjoint for each dim and accumulate all three contributions into
// the nodal result via add(). Thread/register layout matches interp3d.
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  //use P1D for one of these
  CeedScalar r_U[Q1D];
  CeedScalar r_V[Q1D];
  CeedScalar r_t[Q1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if(!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // d/dx: G in x, B in y, B in z
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // d/dy: B in x, G in y, B in z
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_G, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 1;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // d/dz: B in x, B in y, G in z
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 2;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      // adjoint of dim 0 (d/dx) -> r_V
      dim = 0;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      // adjoint of dim 1 (d/dy), accumulated into r_V
      dim = 1;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_G, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      // adjoint of dim 2 (d/dz), accumulated into r_V
      dim = 2;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
549 
550 /////////////
551 // Kernels //
552 /////////////
// Interpolation entry point: dispatch to the dimension-specific
// implementation. BASIS_DIM is a compile-time constant, so dead branches are
// eliminated. Dynamic shared memory (`slice`) is sized by the host launcher.
extern "C" __global__ void interp(const CeedInt nelem, const int transpose,
                                  const CeedScalar *c_B,
                                  const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  extern __shared__ double slice[];
  switch (BASIS_DIM) {
  case 1: interp1d(nelem, transpose, c_B, d_U, d_V, slice); break;
  case 2: interp2d(nelem, transpose, c_B, d_U, d_V, slice); break;
  case 3: interp3d(nelem, transpose, c_B, d_U, d_V, slice); break;
  }
}
566 
// Gradient entry point: dispatch to the dimension-specific implementation.
// BASIS_DIM is a compile-time constant, so dead branches are eliminated.
// Dynamic shared memory (`slice`) is sized by the host launcher.
extern "C" __global__ void grad(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *c_G,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V) {
  extern __shared__ double slice[];
  switch (BASIS_DIM) {
  case 1: grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice); break;
  case 2: grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice); break;
  case 3: grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice); break;
  }
}
580 
581 /////////////
582 // Weights //
583 /////////////
// Fill w with 1D quadrature weights: threadIdx.x is the quadrature point,
// threadIdx.y the element within the block; elements are strided across the
// grid.
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const CeedScalar wq = qweight1d[qx];
  const CeedInt stride = gridDim.x*blockDim.y;
  for (CeedInt e = blockIdx.x*blockDim.y + threadIdx.y; e < nelem; e += stride) {
    w[e*Q1D + qx] = wq;
  }
}
594 
// Fill w with 2D tensor-product quadrature weights: threadIdx.x/y address the
// quadrature point, threadIdx.z the element within the block; elements are
// strided across the grid.
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy];
  const CeedInt stride = gridDim.x*blockDim.z;
  for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < nelem; e += stride) {
    w[e*Q1D*Q1D + qy*Q1D + qx] = wq;
  }
}
606 
// Fill w with 3D tensor-product quadrature weights: the block covers one
// Q1D^3 element; elements are strided across the grid.
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const int qz = threadIdx.z;
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy]*qweight1d[qz];
  for (CeedInt e = blockIdx.x; e < nelem; e += gridDim.x) {
    w[((e*Q1D + qz)*Q1D + qy)*Q1D + qx] = wq;
  }
}
618 
// Quadrature-weight entry point: dispatch to the dimension-specific filler.
// BASIS_DIM is a compile-time constant, so dead branches are eliminated.
extern "C" __global__ void weight(const CeedInt nelem,
                                  const CeedScalar *__restrict__ qweight1d,
                                  CeedScalar *__restrict__ v) {
  switch (BASIS_DIM) {
  case 1: weight1d(nelem, qweight1d, v); break;
  case 2: weight2d(nelem, qweight1d, v); break;
  case 3: weight3d(nelem, qweight1d, v); break;
  }
}
630 
631 );
632 
// Stage the interpolation matrix B (and, for the second helper, the gradient
// matrix G) for use as kernel arguments; device pointers are returned
// through c_B / c_B_ptr / c_G_ptr. NOTE(review): implemented elsewhere in
// the CUDA backend - the exact memory placement is not visible from this
// file.
int CeedCudaInitInterp(CeedScalar *d_B, CeedInt P1d, CeedInt Q1d,
                       CeedScalar **c_B);
int CeedCudaInitInterpGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P1d,
                           CeedInt Q1d, CeedScalar **c_B_ptr,
                           CeedScalar **c_G_ptr);
638 
639 int CeedBasisApplyTensor_Cuda_shared(CeedBasis basis, const CeedInt nelem,
640                                      CeedTransposeMode tmode,
641                                      CeedEvalMode emode, CeedVector u,
642                                      CeedVector v) {
643   int ierr;
644   Ceed ceed;
645   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
646   Ceed_Cuda_shared *ceed_Cuda;
647   CeedGetData(ceed, (void *) &ceed_Cuda); CeedChk(ierr);
648   CeedBasis_Cuda_shared *data;
649   CeedBasisGetData(basis, (void *)&data); CeedChk(ierr);
650   const CeedInt transpose = tmode == CEED_TRANSPOSE;
651   CeedInt dim, ncomp;
652   ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
653   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
654 
655   const CeedScalar *d_u;
656   CeedScalar *d_v;
657   if(emode!=CEED_EVAL_WEIGHT) {
658     ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChk(ierr);
659   }
660   ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChk(ierr);
661 
662   if (tmode == CEED_TRANSPOSE) {
663     CeedInt length;
664     ierr = CeedVectorGetLength(v, &length); CeedChk(ierr);
665     ierr = cudaMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChk(ierr);
666   }
667   if (emode == CEED_EVAL_INTERP) {
668     CeedInt P1d, Q1d;
669     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
670     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
671     ierr = CeedCudaInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B);
672     CeedChk(ierr);
673     void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B, &d_u, &d_v};
674     if (dim==1) {
675       CeedInt elemsPerBlock = 32;
676       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
677                                              ? 1 : 0 );
678       CeedInt sharedMem = elemsPerBlock*Q1d*sizeof(CeedScalar);
679       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, 1,
680                                         elemsPerBlock, sharedMem,
681                                         interpargs);
682       CeedChk(ierr);
683     } else if (dim==2) {
684       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
685       CeedInt elemsPerBlock = Q1d < 7 ? optElems[Q1d]/ncomp : 1;
686       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
687                                              ? 1 : 0 );
688       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
689       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, Q1d,
690                                         ncomp*elemsPerBlock, sharedMem,
691                                         interpargs);
692       CeedChk(ierr);
693     } else if (dim==3) {
694       CeedInt elemsPerBlock = 1;
695       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
696                                              ? 1 : 0 );
697       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
698       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, Q1d,
699                                         ncomp*elemsPerBlock, sharedMem,
700                                         interpargs);
701       CeedChk(ierr);
702     }
703   } else if (emode == CEED_EVAL_GRAD) {
704     CeedInt P1d, Q1d;
705     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
706     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
707     ierr = CeedCudaInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d,
708                                   Q1d, &data->c_B, &data->c_G);
709     CeedChk(ierr);
710     void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B, &data->c_G, &d_u, &d_v};
711     if (dim==1) {
712       CeedInt elemsPerBlock = 32;
713       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
714                                              ? 1 : 0 );
715       CeedInt sharedMem = elemsPerBlock*Q1d*sizeof(CeedScalar);
716       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, 1, elemsPerBlock,
717                                         sharedMem,
718                                         gradargs);
719       CeedChk(ierr);
720     } else if (dim==2) {
721       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
722       CeedInt elemsPerBlock = Q1d < 7 ? optElems[Q1d]/ncomp : 1;
723       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
724                                              ? 1 : 0 );
725       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
726       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, Q1d,
727                                         ncomp*elemsPerBlock, sharedMem,
728                                         gradargs);
729       CeedChk(ierr);
730     } else if (dim==3) {
731       CeedInt elemsPerBlock = 1;
732       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
733                                              ? 1 : 0 );
734       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
735       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, Q1d,
736                                         ncomp*elemsPerBlock, sharedMem,
737                                         gradargs);
738       CeedChk(ierr);
739     }
740   } else if (emode == CEED_EVAL_WEIGHT) {
741     CeedInt Q1d;
742     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
743     void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
744     if(dim==1) {
745       const CeedInt elemsPerBlock = 32/Q1d;
746       const CeedInt gridsize = nelem/elemsPerBlock + ( (
747                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
748       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d,
749                                   elemsPerBlock, 1, weightargs);
750       CeedChk(ierr);
751     } else if(dim==2) {
752       const CeedInt optElems = 32/(Q1d*Q1d);
753       const CeedInt elemsPerBlock = optElems>0?optElems:1;
754       const CeedInt gridsize = nelem/elemsPerBlock + ( (
755                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
756       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d,
757                                   elemsPerBlock, weightargs);
758       CeedChk(ierr);
759     } else if(dim==3) {
760       const CeedInt gridsize = nelem;
761       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d, Q1d,
762                                   weightargs);
763       CeedChk(ierr);
764     }
765   }
766 
767   if(emode!=CEED_EVAL_WEIGHT) {
768     ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChk(ierr);
769   }
770   ierr = CeedVectorRestoreArray(v, &d_v); CeedChk(ierr);
771 
772   return 0;
773 }
774 
775 static int CeedBasisDestroy_Cuda_shared(CeedBasis basis) {
776   int ierr;
777   Ceed ceed;
778   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
779 
780   CeedBasis_Cuda_shared *data;
781   ierr = CeedBasisGetData(basis, (void *) &data); CeedChk(ierr);
782 
783   CeedChk_Cu(ceed, cuModuleUnload(data->module));
784 
785   ierr = cudaFree(data->d_qweight1d); CeedChk_Cu(ceed, ierr);
786   ierr = cudaFree(data->d_interp1d); CeedChk_Cu(ceed, ierr);
787   ierr = cudaFree(data->d_grad1d); CeedChk_Cu(ceed, ierr);
788 
789   ierr = CeedFree(&data); CeedChk(ierr);
790 
791   return 0;
792 }
793 
794 int CeedBasisCreateTensorH1_Cuda_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
795                                         const CeedScalar *interp1d,
796                                         const CeedScalar *grad1d,
797                                         const CeedScalar *qref1d,
798                                         const CeedScalar *qweight1d,
799                                         CeedBasis basis) {
800   int ierr;
801   Ceed ceed;
802   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
803   if (Q1d<P1d) {
804     return CeedError(ceed, 1, "Backend does not implement underintegrated basis.");
805   }
806   CeedBasis_Cuda_shared *data;
807   ierr = CeedCalloc(1, &data); CeedChk(ierr);
808 
809   const CeedInt qBytes = Q1d * sizeof(CeedScalar);
810   ierr = cudaMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Cu(ceed, ierr);
811   ierr = cudaMemcpy(data->d_qweight1d, qweight1d, qBytes,
812                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
813 
814   const CeedInt iBytes = qBytes * P1d;
815   ierr = cudaMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Cu(ceed, ierr);
816   ierr = cudaMemcpy(data->d_interp1d, interp1d, iBytes,
817                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
818 
819   ierr = cudaMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Cu(ceed, ierr);
820   ierr = cudaMemcpy(data->d_grad1d, grad1d, iBytes,
821                     cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
822 
823   data->d_collograd1d = NULL;
824   if (dim==3 && Q1d >= P1d) {
825     CeedScalar *collograd1d;
826     ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChk(ierr);
827     ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChk(ierr);
828     ierr = cudaMalloc((void **)&data->d_collograd1d, qBytes * Q1d);
829     CeedChk_Cu(ceed, ierr);
830     ierr = cudaMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d,
831                       cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);
832   }
833 
834   CeedInt ncomp;
835   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
836   ierr = CeedCompileCuda(ceed, kernelsShared, &data->module, 7,
837                          "Q1D", Q1d,
838                          "P1D", P1d,
839                          "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
840                              Q1d : P1d, dim),
841                          "BASIS_DIM", dim,
842                          "BASIS_NCOMP", ncomp,
843                          "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
844                          "BASIS_NQPT", CeedIntPow(Q1d, dim)
845                         ); CeedChk(ierr);
846   ierr = CeedGetKernelCuda(ceed, data->module, "interp", &data->interp);
847   CeedChk(ierr);
848   ierr = CeedGetKernelCuda(ceed, data->module, "grad", &data->grad);
849   CeedChk(ierr);
850   ierr = CeedGetKernelCuda(ceed, data->module, "weight", &data->weight);
851   CeedChk(ierr);
852 
853   ierr = CeedBasisSetData(basis, (void *)&data);
854   CeedChk(ierr);
855   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
856                                 CeedBasisApplyTensor_Cuda_shared);
857   CeedChk(ierr);
858   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
859                                 CeedBasisDestroy_Cuda_shared);
860   CeedChk(ierr);
861   return 0;
862 }
863