xref: /libCEED/backends/cuda-shared/ceed-cuda-shared-basis.c (revision c172e0a6730c614967cdf24760cc19b8004cf29b)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <ceed-backend.h>
18 #include <ceed.h>
19 #include "ceed-cuda-shared.h"
20 #include "../cuda/ceed-cuda.h"
21 
22 //*********************
23 // shared mem kernels
24 static const char *kernelsShared = QUOTE(
25 
// Accumulate: r_V[q] += r_U[q] for each of the Q1D per-thread register slots.
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  for (int q = 0; q < Q1D; q++) {
    r_V[q] += r_U[q];
  }
}
30 
31 //////////
32 //  1D  //
33 //////////
34 
// Load one element's nodal values for component `comp` into this thread's
// tidz-row of the shared `slice`, zero-padding entries P1D..Q1D-1 so the
// contractions can run over Q1D entries.  tidx and tidy are unused.
// NOTE(review): every thread with the same tidz writes the same values into
// the same slice row — redundant but appears benign; confirm intended.
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int tidz,const int comp,
                                  const int nelem, const CeedScalar *d_U, CeedScalar *slice) {
  // Nodal layout: [node, comp, elem] with P1D nodes per component.
  for (int i = 0; i < P1D; i++)
    slice[i+tidz*Q1D] = d_U[i + comp*P1D + elem*BASIS_NCOMP*P1D];
  // Zero padding between P1D and Q1D.
  for (int i = P1D; i < Q1D; i++)
    slice[i+tidz*Q1D] = 0.0;
}
43 
// Store one element's nodal result for component `comp`; only the first P1D
// threads in x hold valid node values.  tidy is unused.
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx >= P1D) return;  // threads past the node count carry no data
  const int ind = tidx + comp*P1D + elem*BASIS_NCOMP*P1D;
  d_V[ind] = r_V;
}
51 
// Load one element's quadrature values for (comp, dim) into this thread's
// tidz-row of the shared `slice`.  tidx and tidy are unused.
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int tidz, const int comp,
                                   const int dim, const int nelem, const CeedScalar *d_U, CeedScalar *slice) {
  // Quadrature layout: [quad, elem, comp, dim].
  const int base = elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  for (int q = 0; q < Q1D; q++) {
    slice[q + tidz*Q1D] = d_U[base + q];
  }
}
59 
// Store this thread's quadrature-point result for (comp, dim); one value per
// thread in x.  tidy is unused.
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem, const CeedScalar &r_V, CeedScalar *d_V) {
  const int ind = tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  d_V[ind] = r_V;
}
65 
// 1D forward contraction: V = sum_i B(i, tidx) * slice(i) over the P1D nodes
// in this thread's tidz-row of shared memory.  The U parameter is unused
// (kept for signature symmetry with the 2D/3D contractions); the inputs are
// read from `slice`, not from U.
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  V = 0.0;
  for (int i = 0; i < P1D; ++i) {
    V += B[i + tidx*P1D] * slice[i+tidz*Q1D];//contract x direction
  }
}
74 
// 1D transpose contraction: V = sum_i B(tidx, i) * slice(i) over the Q1D
// quadrature values in this thread's tidz-row.  The U parameter is unused
// (kept for signature symmetry); inputs come from `slice`.
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  V = 0.0;
  for (int i = 0; i < Q1D; ++i) {
    V += B[tidx + i*P1D] * slice[i+tidz*Q1D];//contract x direction
  }
}
83 
// 1D interpolation over all elements.
// !transpose: nodes -> quadrature points; transpose: the adjoint map back.
// Each z-slice of the block handles one element (grid-stride over elements);
// `slice` is dynamic shared memory holding blockDim.z rows of Q1D scalars.
// r_t is passed only as the (unused) U argument of the 1D contractions.
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;


  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp=0; comp<BASIS_NCOMP; comp++) {
      if(!transpose) {
        // load nodes into slice, contract with B, write quadrature values
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      } else {
        // load quadrature values, apply transpose of B, write nodes
        readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
111 
// 1D gradient over all elements.  In 1D there is only the x-derivative, so a
// single contraction with the gradient matrix c_G is applied (c_B is unused
// by the contractions here but kept in the signature for uniformity).
// r_U is passed only as the (unused) U argument of the 1D contractions.
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  int dim;

  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp=0; comp<BASIS_NCOMP; comp++) {
      if(!transpose) {
        // nodes -> d/dx at quadrature points
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        dim = 0;
        writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      } else {
        // transpose: quadrature d/dx values -> nodes
        dim = 0;
        readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
141 //////////
142 //  2D  //
143 //////////
144 
// Load this thread's nodal value for (elem, comp) into the register U.
// Threads outside the P1D x P1D node grid receive zero.
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U, CeedScalar &U) {
  if (tidx < P1D && tidy < P1D) {
    U = d_U[tidx + tidy*P1D + comp*P1D*P1D + elem*BASIS_NCOMP*P1D*P1D];
  } else {
    U = 0.0;
  }
}
152 
// Store this thread's nodal result for (elem, comp); only threads inside the
// P1D x P1D node grid hold valid data.
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D) return;  // no node behind this thread
  const int ind = tidx + tidy*P1D + comp*P1D*P1D + elem*BASIS_NCOMP*P1D*P1D;
  d_V[ind] = r_V;
}
160 
// Load this thread's quadrature value for (elem, comp, dim) into register U.
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem, const CeedScalar *d_U, CeedScalar &U ) {
  // Quadrature layout: [qx, qy, elem, comp, dim].
  const int base = elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                   dim*BASIS_NCOMP*nelem*Q1D*Q1D;
  U = d_U[base + tidx + tidy*Q1D];
}
167 
// Store this thread's quadrature-point result for (elem, comp, dim).
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem, const CeedScalar &r_V, CeedScalar *d_V) {
  const int base = elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                   dim*BASIS_NCOMP*nelem*Q1D*Q1D;
  d_V[base + tidx + tidy*Q1D] = r_V;
}
174 
// 2D forward contraction along x: V(tidx,tidy) = sum_i B(i,tidx)*U(i,tidy).
// U is first staged in shared memory so threads can read each other's
// x-entries; both barriers are reached by all threads in the block.
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U;
  __syncthreads();  // all writes to slice must land before any reads
  V = 0.0;
  for (int i = 0; i < P1D; ++i) {
    V += B[i + tidx*P1D] * slice[i + tidy*Q1D + tidz*Q1D*Q1D];//contract x direction
  }
  __syncthreads();  // slice is reused by the caller's next contraction
}
186 
// 2D forward contraction along y: V(tidx,tidy) = sum_i B(i,tidy)*U(tidx,i).
// U is staged in shared memory so threads can read each other's y-entries.
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U;
  __syncthreads();  // all writes to slice must land before any reads
  V = 0.0;
  for (int i = 0; i < P1D; ++i) {
    V += B[i + tidy*P1D] * slice[tidx + i*Q1D + tidz*Q1D*Q1D];//contract y direction
  }
  __syncthreads();  // slice is reused by the caller's next contraction
}
198 
// 2D transpose contraction along y: V(tidx,tidy) = sum_i B(tidy,i)*U(tidx,i)
// for tidy < P1D; other threads get V = 0.  Note both __syncthreads() sit
// outside the divergent guard so every thread in the block reaches them.
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidy<P1D) {
    for (int i = 0; i < Q1D; ++i) {
      V += B[tidy + i*P1D] * slice[tidx + i*Q1D + tidz*Q1D*Q1D];//contract y direction
    }
  }
  __syncthreads();
}
212 
// 2D transpose contraction along x: V(tidx,tidy) = sum_i B(tidx,i)*U(i,tidy)
// for tidx < P1D; other threads get V = 0.  Both barriers sit outside the
// divergent guard so every thread in the block reaches them.
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidx<P1D) {
    for (int i = 0; i < Q1D; ++i) {
      V += B[tidx + i*P1D] * slice[i + tidy*Q1D + tidz*Q1D*Q1D];//contract x direction
    }
  }
  __syncthreads();
}
226 
// 2D interpolation over all elements.
// !transpose: nodes -> quadrature points (B applied in x then y);
// transpose: the adjoint map back to nodes.
// Each z-slice of the block handles one (element, component) pair;
// `slice` is dynamic shared memory of blockDim.z*Q1D*Q1D scalars.
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // (removed a redundant re-declaration of `comp` that shadowed the one above)
    r_V = 0.0;
    r_t = 0.0;
    if(!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
    } else {
      readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
259 
// 2D gradient over all elements.
// !transpose: nodes -> (d/dx, d/dy) at quadrature points, one tensor
// contraction chain per derivative direction (G in that direction, B in the
// other).  transpose: the adjoint, summing both directions' contributions
// into the nodal result.  Each z-slice handles one (element, component).
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if(!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // dim 0: derivative in x (G in x, B in y)
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // dim 1: derivative in y (B in x, G in y)
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 1;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      // adjoint of dim 0 into r_V
      dim = 0;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      // adjoint of dim 1 into r_U, then accumulate
      dim = 1;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      r_V+=r_U;
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
302 //////////
303 //  3D  //
304 //////////
305 
// Load this thread's column of nodal values (one per z-level) for
// (elem, comp) into registers r_U[0..P1D-1]; pad r_U[P1D..Q1D-1] with zero.
// Threads outside the P1D x P1D node grid receive all zeros.
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U, CeedScalar *r_U) {
  const bool active = (tidx < P1D) && (tidy < P1D);
  for (int k = 0; k < P1D; k++) {
    r_U[k] = active ? d_U[tidx + tidy*P1D + k*P1D*P1D + comp*P1D*P1D*P1D +
                          elem*BASIS_NCOMP*P1D*P1D*P1D] : 0.0;
  }
  for (int k = P1D; k < Q1D; k++) {
    r_U[k] = 0.0;
  }
}
316 
// Load this thread's column of quadrature values (one per z-level) for
// (elem, comp, dim) into registers r_U[0..Q1D-1].
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem, const CeedScalar *d_U, CeedScalar *r_U) {
  // Quadrature layout: [qx, qy, qz, elem, comp, dim].
  const int base = elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
                   dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
  for (int k = 0; k < Q1D; k++) {
    r_U[k] = d_U[base + tidx + tidy*Q1D + k*Q1D*Q1D];
  }
}
324 
// Store this thread's column of nodal results for (elem, comp); only threads
// inside the P1D x P1D node grid hold valid data.
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V, CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D) return;  // no node behind this thread
  const int base = comp*P1D*P1D*P1D + elem*BASIS_NCOMP*P1D*P1D*P1D;
  for (int k = 0; k < P1D; k++) {
    d_V[base + tidx + tidy*P1D + k*P1D*P1D] = r_V[k];
  }
}
334 
// Store this thread's column of quadrature-point results for
// (elem, comp, dim), one value per z-level.
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem, const CeedScalar *r_V, CeedScalar *d_V) {
  const int base = elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
                   dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
  for (int k = 0; k < Q1D; k++) {
    d_V[base + tidx + tidy*Q1D + k*Q1D*Q1D] = r_V[k];
  }
}
342 
// 3D forward contraction along x, one z-level at a time:
// V[k](tidx,tidy) = sum_i B(i,tidx) * U[k](i,tidy).
// Each level is staged in shared memory; both barriers per level are reached
// by all threads in the block.
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U[k];
    __syncthreads();  // writes for level k must land before reads
    V[k] = 0.0;
    for (int i = 0; i < P1D; ++i) {
      V[k] += B[i + tidx*P1D] * slice[i + tidy*Q1D +
                                      tidz*Q1D*Q1D];//contract x direction
    }
    __syncthreads();  // slice is overwritten with level k+1 next
  }
}
357 
// 3D forward contraction along y, one z-level at a time:
// V[k](tidx,tidy) = sum_i B(i,tidy) * U[k](tidx,i).
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U[k];
    __syncthreads();  // writes for level k must land before reads
    V[k] = 0.0;
    for (int i = 0; i < P1D; ++i) {
      V[k] += B[i + tidy*P1D] * slice[tidx + i*Q1D +
                                      tidz*Q1D*Q1D];//contract y direction
    }
    __syncthreads();  // slice is overwritten with level k+1 next
  }
}
372 
// 3D forward contraction along z: V[k] = sum_i B(i,k) * U[i].
// The z-column lives entirely in this thread's registers, so no shared
// memory or synchronization is needed; slice/tidx/tidy/tidz are unused
// (kept for signature symmetry with the x/y contractions).
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < Q1D; ++k) {
    V[k] = 0.0;
    for (int i = 0; i < P1D; ++i) {
      V[k] += B[i + k*P1D] * U[i];//contract z direction
    }
  }
}
383 
// 3D transpose contraction along z: V[k] = sum_i B(k,i) * U[i] for k < P1D,
// zero for the padding levels k >= P1D.  Register-only; slice and the tid
// parameters are unused (kept for signature symmetry).
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < Q1D; ++k) {
    V[k] = 0.0;
    if (k<P1D) {
      for (int i = 0; i < Q1D; ++i) {
        V[k] += B[k + i*P1D] * U[i];//contract z direction
      }
    }
  }
}
396 
// 3D transpose contraction along y, one z-level at a time:
// V[k](tidx,tidy) = sum_i B(tidy,i) * U[k](tidx,i) for tidy < P1D; zero
// otherwise.  Both barriers per level sit outside the divergent guard so
// every thread in the block reaches them.
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidy<P1D) {
      for (int i = 0; i < Q1D; ++i) {
        V[k] += B[tidy + i*P1D] * slice[tidx + i*Q1D +
                                        tidz*Q1D*Q1D];//contract y direction
      }
    }
    __syncthreads();
  }
}
413 
// 3D transpose contraction along x, one z-level at a time:
// V[k](tidx,tidy) = sum_i B(tidx,i) * U[k](i,tidy) for tidx < P1D; zero
// otherwise.  Both barriers per level sit outside the divergent guard so
// every thread in the block reaches them.
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx<P1D) {
      for (int i = 0; i < Q1D; ++i) {
        V[k] += B[tidx + i*P1D] * slice[i + tidy*Q1D +
                                        tidz*Q1D*Q1D];//contract x direction
      }
    }
    __syncthreads();
  }
}
430 
// 3D interpolation over all elements.
// !transpose: nodes -> quadrature points via contractions in x, y, then z;
// transpose: the adjoint chain in reverse order (z, y, x).
// Each z-slice of the block handles one (element, component) pair;
// `slice` is dynamic shared memory of blockDim.z*Q1D*Q1D scalars; the
// z-column of each thread lives in registers r_V/r_t.
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V[Q1D];
  CeedScalar r_t[Q1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // clear per-thread registers before reuse for this element
    for (int i = 0; i < Q1D; ++i) {
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if(!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
    } else {
      readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
    }
  }
}
466 
// 3D gradient over all elements.
// !transpose: nodes -> (d/dx, d/dy, d/dz) at quadrature points; one x/y/z
// contraction chain per derivative direction with c_G applied in that
// direction and c_B in the other two.  transpose: the adjoint, accumulating
// all three directions' contributions (via add()) into the nodal result.
// Each z-slice of the block handles one (element, component) pair.
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // TODO(upstream note kept): one of these could be sized P1D to save registers
  CeedScalar r_U[Q1D];
  CeedScalar r_V[Q1D];
  CeedScalar r_t[Q1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if(!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // dim 0: G in x, B in y and z
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // dim 1: G in y
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_G, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 1;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // dim 2: G in z
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 2;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      // adjoint of dim 0 into r_V
      dim = 0;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      // adjoint of dim 1, accumulated into r_V
      dim = 1;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_G, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      // adjoint of dim 2, accumulated into r_V
      dim = 2;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
525 
526 /////////////
527 // Kernels //
528 /////////////
// Entry point: dispatch interpolation to the dimension-specific helper.
// BASIS_DIM is a JIT-time constant, so the branch is resolved at compile
// time.  The dynamic shared buffer is typed CeedScalar (not hard-coded
// double) to match the helpers and the host-side sizing with
// sizeof(CeedScalar).
extern "C" __global__ void interp(const CeedInt nelem, const int transpose,
                                  const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  extern __shared__ CeedScalar slice[];
  if (BASIS_DIM==1) {
    interp1d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM==2) {
    interp2d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM==3) {
    interp3d(nelem, transpose, c_B, d_U, d_V, slice);
  }
}
541 
// Entry point: dispatch the gradient to the dimension-specific helper.
// BASIS_DIM is a JIT-time constant, so the branch is resolved at compile
// time.  The dynamic shared buffer is typed CeedScalar (not hard-coded
// double) to match the helpers and the host-side sizing with
// sizeof(CeedScalar).
extern "C" __global__ void grad(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *c_G,
                                const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V) {
  extern __shared__ CeedScalar slice[];
  if (BASIS_DIM==1) {
    grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM==2) {
    grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM==3) {
    grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  }
}
554 
555 /////////////
556 // Weights //
557 /////////////
// Write the 1D quadrature weight for each point of each element.
// Thread x indexes the quadrature point; thread y indexes the element within
// the block; grid-stride over elements.
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int q = threadIdx.x;
  const CeedScalar wq = qweight1d[q];
  for (CeedInt e = blockIdx.x*blockDim.y + threadIdx.y; e < nelem;
       e += gridDim.x*blockDim.y) {
    w[e*Q1D + q] = wq;
  }
}
568 
// Write the tensor-product 2D quadrature weight for each point of each
// element.  Threads x/y index the quadrature point; thread z indexes the
// element within the block; grid-stride over elements.
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy];
  for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < nelem;
       e += gridDim.x*blockDim.z) {
    w[e*Q1D*Q1D + qx + qy*Q1D] = wq;
  }
}
580 
// Write the tensor-product 3D quadrature weight for each point of each
// element.  Threads x/y/z index the quadrature point; one element per block,
// grid-stride over elements.
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const int qz = threadIdx.z;
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy]*qweight1d[qz];
  for (int e = blockIdx.x; e < nelem; e += gridDim.x) {
    w[e*Q1D*Q1D*Q1D + qx + qy*Q1D + qz*Q1D*Q1D] = wq;
  }
}
592 
// Entry point: dispatch quadrature-weight computation to the
// dimension-specific helper.  BASIS_DIM is a JIT-time constant, so the
// branch is resolved at compile time.
extern "C" __global__ void weight(const CeedInt nelem,
                                  const CeedScalar *__restrict__ qweight1d, CeedScalar *__restrict__ v) {
  if (BASIS_DIM==1) {
    weight1d(nelem, qweight1d, v);
  } else if (BASIS_DIM==2) {
    weight2d(nelem, qweight1d, v);
  } else if (BASIS_DIM==3) {
    weight3d(nelem, qweight1d, v);
  }
}
603 
604 );
605 
606 int CeedCudaInitInterp(CeedScalar *d_B, CeedInt P1d, CeedInt Q1d,
607                        CeedScalar **c_B);
608 int CeedCudaInitInterpGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P1d,
609                            CeedInt Q1d, CeedScalar **c_B_ptr, CeedScalar **c_G_ptr);
610 
611 int CeedBasisApplyTensor_Cuda_shared(CeedBasis basis, const CeedInt nelem,
612                                      CeedTransposeMode tmode,
613                                      CeedEvalMode emode, CeedVector u, CeedVector v) {
614   int ierr;
615   Ceed ceed;
616   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
617   Ceed_Cuda_shared *ceed_Cuda;
618   CeedGetData(ceed, (void *) &ceed_Cuda); CeedChk(ierr);
619   CeedBasis_Cuda_shared *data;
620   CeedBasisGetData(basis, (void *)&data); CeedChk(ierr);
621   const CeedInt transpose = tmode == CEED_TRANSPOSE;
622   CeedInt dim, ncomp;
623   ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
624   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
625 
626   const CeedScalar *d_u;
627   CeedScalar *d_v;
628   if(emode!=CEED_EVAL_WEIGHT) {
629     ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChk(ierr);
630   }
631   ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChk(ierr);
632 
633   if (tmode == CEED_TRANSPOSE) {
634     CeedInt length;
635     ierr = CeedVectorGetLength(v, &length); CeedChk(ierr);
636     ierr = cudaMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChk(ierr);
637   }
638   if (emode == CEED_EVAL_INTERP) {
639     CeedInt P1d, Q1d;
640     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
641     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
642     ierr = CeedCudaInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B);
643     CeedChk(ierr);
644     void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B, &d_u, &d_v};
645     if (dim==1) {
646       CeedInt elemsPerBlock = 32;
647       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
648                                              ? 1 : 0 );
649       CeedInt sharedMem = elemsPerBlock*Q1d*sizeof(CeedScalar);
650       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, 1,
651                                         elemsPerBlock, sharedMem,
652                                         interpargs);
653       CeedChk(ierr);
654     } else if (dim==2) {
655       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
656       CeedInt elemsPerBlock = Q1d < 7 ? optElems[Q1d]/ncomp : 1;
657       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
658                                              ? 1 : 0 );
659       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
660       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, Q1d,
661                                         ncomp*elemsPerBlock, sharedMem,
662                                         interpargs);
663       CeedChk(ierr);
664     } else if (dim==3) {
665       CeedInt elemsPerBlock = 1;
666       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
667                                              ? 1 : 0 );
668       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
669       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, Q1d,
670                                         ncomp*elemsPerBlock, sharedMem,
671                                         interpargs);
672       CeedChk(ierr);
673     }
674   } else if (emode == CEED_EVAL_GRAD) {
675     CeedInt P1d, Q1d;
676     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
677     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
678     ierr = CeedCudaInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d,
679                                   Q1d, &data->c_B, &data->c_G);
680     CeedChk(ierr);
681     void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B, &data->c_G, &d_u, &d_v};
682     if (dim==1) {
683       CeedInt elemsPerBlock = 32;
684       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
685                                              ? 1 : 0 );
686       CeedInt sharedMem = elemsPerBlock*Q1d*sizeof(CeedScalar);
687       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, 1, elemsPerBlock,
688                                         sharedMem,
689                                         gradargs);
690       CeedChk(ierr);
691     } else if (dim==2) {
692       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
693       CeedInt elemsPerBlock = Q1d < 7 ? optElems[Q1d]/ncomp : 1;
694       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
695                                              ? 1 : 0 );
696       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
697       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, Q1d,
698                                         ncomp*elemsPerBlock, sharedMem,
699                                         gradargs);
700       CeedChk(ierr);
701     } else if (dim==3) {
702       CeedInt elemsPerBlock = 1;
703       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
704                                              ? 1 : 0 );
705       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
706       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, Q1d,
707                                         ncomp*elemsPerBlock, sharedMem,
708                                         gradargs);
709       CeedChk(ierr);
710     }
711   } else if (emode == CEED_EVAL_WEIGHT) {
712     CeedInt Q1d;
713     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
714     void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
715     if(dim==1) {
716       const CeedInt elemsPerBlock = 32/Q1d;
717       const CeedInt gridsize = nelem/elemsPerBlock + ( (
718                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
719       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, elemsPerBlock, 1,
720                                   weightargs);
721       CeedChk(ierr);
722     } else if(dim==2) {
723       const CeedInt optElems = 32/(Q1d*Q1d);
724       const CeedInt elemsPerBlock = optElems>0?optElems:1;
725       const CeedInt gridsize = nelem/elemsPerBlock + ( (
726                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
727       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d,
728                                   elemsPerBlock, weightargs);
729       CeedChk(ierr);
730     } else if(dim==3) {
731       const CeedInt gridsize = nelem;
732       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d, Q1d,
733                                   weightargs);
734       CeedChk(ierr);
735     }
736   }
737 
738   if(emode!=CEED_EVAL_WEIGHT) {
739     ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChk(ierr);
740   }
741   ierr = CeedVectorRestoreArray(v, &d_v); CeedChk(ierr);
742 
743   return 0;
744 }
745 
//------------------------------------------------------------------------------
// Destroy the backend data of a shared-memory CUDA tensor basis: unload the
// compiled kernel module and free the device-side 1D basis arrays.
// Registered as the "Destroy" backend function in
// CeedBasisCreateTensorH1_Cuda_shared.
//------------------------------------------------------------------------------
static int CeedBasisDestroy_Cuda_shared(CeedBasis basis) {
  int ierr;
  Ceed ceed;
  ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);

  CeedBasis_Cuda_shared *data;
  ierr = CeedBasisGetData(basis, (void *) &data); CeedChk(ierr);

  // Unload the module holding the interp/grad/weight kernels compiled at
  // basis creation; must happen before the backing data struct is freed
  CeedChk_Cu(ceed, cuModuleUnload(data->module));

  // Release the device copies of the 1D quadrature weights and the 1D
  // interpolation and gradient matrices allocated at basis creation
  ierr = cudaFree(data->d_qweight1d); CeedChk_Cu(ceed, ierr);
  ierr = cudaFree(data->d_interp1d); CeedChk_Cu(ceed, ierr);
  ierr = cudaFree(data->d_grad1d); CeedChk_Cu(ceed, ierr);

  ierr = CeedFree(&data); CeedChk(ierr);

  return 0;
}
764 
//------------------------------------------------------------------------------
// Create a tensor-product H1 basis for the shared-memory CUDA backend.
//
// Copies the 1D quadrature weights and the 1D interpolation/gradient matrices
// (each Q1d x P1d) to the device, JIT-compiles the shared-memory kernels with
// the basis sizes baked in as compile-time constants, and registers the
// Apply/Destroy backend functions. qref1d is accepted for interface
// compatibility but is not needed on the device by this backend.
//
// Returns an error for underintegrated bases (Q1d < P1d), which the
// shared-memory kernels do not support.
//------------------------------------------------------------------------------
int CeedBasisCreateTensorH1_Cuda_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
                                        const CeedScalar *interp1d,
                                        const CeedScalar *grad1d,
                                        const CeedScalar *qref1d,
                                        const CeedScalar *qweight1d,
                                        CeedBasis basis) {
  int ierr;
  Ceed ceed;
  ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);

  if (Q1d < P1d)
    return CeedError(ceed, 1, "Backend does not implement underintegrated basis.");

  CeedBasis_Cuda_shared *data;
  ierr = CeedCalloc(1, &data); CeedChk(ierr);

  // Device copy of the 1D quadrature weights (Q1d scalars)
  const CeedInt wBytes = Q1d * sizeof(CeedScalar);
  ierr = cudaMalloc((void **)&data->d_qweight1d, wBytes); CeedChk_Cu(ceed, ierr);
  ierr = cudaMemcpy(data->d_qweight1d, qweight1d, wBytes,
                    cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);

  // Device copies of the Q1d x P1d interpolation and gradient matrices
  const CeedInt mBytes = P1d * wBytes;
  ierr = cudaMalloc((void **)&data->d_interp1d, mBytes); CeedChk_Cu(ceed, ierr);
  ierr = cudaMemcpy(data->d_interp1d, interp1d, mBytes,
                    cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);

  ierr = cudaMalloc((void **)&data->d_grad1d, mBytes); CeedChk_Cu(ceed, ierr);
  ierr = cudaMemcpy(data->d_grad1d, grad1d, mBytes,
                    cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);

  // Compile the shared-memory kernels; sizes become compile-time constants so
  // the kernels can statically size their slices and unroll loops
  CeedInt ncomp;
  ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
  const CeedInt maxP1dQ1d = Q1d > P1d ? Q1d : P1d; // == Q1d given guard above
  ierr = CeedCompileCuda(ceed, kernelsShared, &data->module, 7,
                         "Q1D", Q1d,
                         "P1D", P1d,
                         "BASIS_BUF_LEN", ncomp * CeedIntPow(maxP1dQ1d, dim),
                         "BASIS_DIM", dim,
                         "BASIS_NCOMP", ncomp,
                         "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
                         "BASIS_NQPT", CeedIntPow(Q1d, dim)); CeedChk(ierr);
  ierr = CeedGetKernelCuda(ceed, data->module, "interp", &data->interp);
  CeedChk(ierr);
  ierr = CeedGetKernelCuda(ceed, data->module, "grad", &data->grad);
  CeedChk(ierr);
  ierr = CeedGetKernelCuda(ceed, data->module, "weight", &data->weight);
  CeedChk(ierr);

  ierr = CeedBasisSetData(basis, (void *)&data); CeedChk(ierr);

  // Register the backend implementations for this basis object
  ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
                                CeedBasisApplyTensor_Cuda_shared);
  CeedChk(ierr);
  ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
                                CeedBasisDestroy_Cuda_shared);
  CeedChk(ierr);
  return 0;
}
823