xref: /libCEED/backends/cuda-shared/ceed-cuda-shared-basis.c (revision 8c91a0c98588ec4c767fca1c07a5e6e035c21fe6)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <ceed-backend.h>
18 #include <ceed.h>
19 #include "ceed-cuda-shared.h"
20 #include "../cuda/ceed-cuda.h"
21 
22 //*********************
23 // shared mem kernels
24 static const char *kernelsShared = QUOTE(
25 
// Accumulate a Q1D-long register array: r_V[q] += r_U[q] for every entry.
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  for (int q = 0; q < Q1D; ++q) {
    r_V[q] += r_U[q];
  }
}
30 
31 //////////
32 //  1D  //
33 //////////
34 
// Load the P1D nodal values of one (elem, comp) pair into this thread's
// z-slice of shared memory, zero-padding entries [P1D, Q1D).
// Every thread of a z-slice stores the same values (redundant but identical
// stores to the same addresses).
// NOTE(review): no barrier follows before the slice is read by the caller's
// contraction — presumably relies on a Q1D-wide row living in one warp;
// confirm for Q1D > warp size.
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int tidz,const int comp,
                                  const int nelem, const CeedScalar *d_U, CeedScalar *slice) {
  for (int i = 0; i < P1D; i++)
    slice[i+tidz*Q1D] = d_U[i + comp*P1D + elem*BASIS_NCOMP*P1D];
  // Zero-pad so Q1D-length contractions see zeros past the nodes.
  for (int i = P1D; i < Q1D; i++)
    slice[i+tidz*Q1D] = 0.0;
}
43 
// Store one 1D nodal value: thread tidx writes its result only when tidx is a
// valid node index (tidx < P1D); the extra quadrature threads write nothing.
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx >= P1D) return;
  const int ind = tidx + comp*P1D + elem*BASIS_NCOMP*P1D;
  d_V[ind] = r_V;
}
51 
// Load the Q1D quadrature values of one (elem, comp, dim) line into this
// thread's z-slice of shared memory (redundant identical stores per slice).
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int tidz, const int comp,
                                   const int dim, const int nelem, const CeedScalar *d_U, CeedScalar *slice) {
  const int base = elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  for (int q = 0; q < Q1D; ++q) {
    slice[q + tidz*Q1D] = d_U[base + q];
  }
}
58 
// Store one quadrature-point value for (elem, comp, dim); thread tidx owns
// quadrature point tidx.
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem, const CeedScalar &r_V, CeedScalar *d_V) {
  const int ind = tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  d_V[ind] = r_V;
}
64 
// 1D x-contraction: V(tidx) = sum_i B(i, tidx) * slice(i) over the P1D nodal
// values staged in this thread's z-slice. B is stored node-fastest
// (B[i + tidx*P1D]).
// NOTE(review): parameter U is unused; the input comes from `slice`, which the
// caller fills with no intervening barrier — verify warp-synchronous access.
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  V = 0.0;
  for (int i = 0; i < P1D; ++i) {
    V += B[i + tidx*P1D] * slice[i+tidz*Q1D];//contract x direction
  }
}
73 
// Transposed 1D x-contraction: V(tidx) = sum_i B(tidx, i) * slice(i) over the
// Q1D quadrature values in this z-slice (transposed access B[tidx + i*P1D]).
// NOTE(review): U is unused, mirroring ContractX1d; same implicit-sync caveat.
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  V = 0.0;
  for (int i = 0; i < Q1D; ++i) {
    V += B[tidx + i*P1D] * slice[i+tidz*Q1D];//contract x direction
  }
}
82 
// 1D tensor interpolation over a grid-stride loop of elements (blockDim.z
// elements stacked per block). Forward: nodes -> quadrature via B; transpose:
// quadrature -> nodes via B^T.
// NOTE(review): r_t is passed as the (unused) U argument of the contractions;
// it is never read or written here.
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;


  // Grid-stride loop: each z-layer of the block owns one element per pass.
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp=0; comp<BASIS_NCOMP; comp++) {
      if(!transpose) {
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      } else {
        readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
110 
// 1D tensor gradient: same structure as interp1d but contracts with the
// gradient matrix G instead of B (single dim 0 in 1D).
// NOTE(review): r_U is passed as the unused U argument of the contractions.
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  int dim;

  // Grid-stride loop over elements, one per z-layer of the block.
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp=0; comp<BASIS_NCOMP; comp++) {
      if(!transpose) {
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        dim = 0;
        writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      } else {
        dim = 0;
        readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
140 //////////
141 //  2D  //
142 //////////
143 
// Load one 2D nodal value into a register: threads at a valid node position
// (tidx, tidy both < P1D) read from d_U; padding threads get 0.
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U, CeedScalar &U) {
  if (tidx < P1D && tidy < P1D) {
    U = d_U[tidx + tidy*P1D + comp*P1D*P1D + elem*BASIS_NCOMP*P1D*P1D];
  } else {
    U = 0.0;
  }
}
151 
// Store one 2D nodal value; only threads at a valid node position write.
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V, CeedScalar *d_V) {
  const bool isNode = (tidx < P1D) && (tidy < P1D);
  if (isNode) {
    const int ind = tidx + tidy*P1D + comp*P1D*P1D + elem*BASIS_NCOMP*P1D*P1D;
    d_V[ind] = r_V;
  }
}
159 
// Load one 2D quadrature-point value for (elem, comp, dim) into a register.
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem, const CeedScalar *d_U, CeedScalar &U ) {
  const int ind = tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                  dim*BASIS_NCOMP*nelem*Q1D*Q1D;
  U = d_U[ind];
}
166 
// Store one 2D quadrature-point value for (elem, comp, dim).
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem, const CeedScalar &r_V, CeedScalar *d_V) {
  const int ind = tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                  dim*BASIS_NCOMP*nelem*Q1D*Q1D;
  d_V[ind] = r_V;
}
173 
// 2D x-contraction: stage each thread's value U in shared memory, then
// V(tidx,tidy) = sum_i B(i,tidx) * slice(i,tidy).
// Both __syncthreads() are unconditional: the first orders the staging store
// before any read, the second keeps the slice intact until all reads finish
// (the next contraction overwrites it).
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  for (int i = 0; i < P1D; ++i) {
    V += B[i + tidx*P1D] * slice[i + tidy*Q1D + tidz*Q1D*Q1D];//contract x direction
  }
  __syncthreads();
}
185 
// 2D y-contraction: stage U in shared memory, then
// V(tidx,tidy) = sum_i B(i,tidy) * slice(tidx,i).
// Barriers bracket the staged slice exactly as in ContractX2d.
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  for (int i = 0; i < P1D; ++i) {
    V += B[i + tidy*P1D] * slice[tidx + i*Q1D + tidz*Q1D*Q1D];//contract y direction
  }
  __syncthreads();
}
197 
// Transposed 2D y-contraction: V(tidx,tidy) = sum_i B(tidy,i) * slice(tidx,i)
// for node rows tidy < P1D; out-of-range threads leave V = 0.
// The guard only skips the summation — both barriers stay outside the
// divergent branch so every thread of the block reaches them.
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidy<P1D) {
    for (int i = 0; i < Q1D; ++i) {
      V += B[tidy + i*P1D] * slice[tidx + i*Q1D + tidz*Q1D*Q1D];//contract y direction
    }
  }
  __syncthreads();
}
211 
// Transposed 2D x-contraction: V(tidx,tidy) = sum_i B(tidx,i) * slice(i,tidy)
// for node columns tidx < P1D; out-of-range threads leave V = 0.
// Barriers are unconditional (outside the divergent guard).
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidx<P1D) {
    for (int i = 0; i < Q1D; ++i) {
      V += B[tidx + i*P1D] * slice[i + tidy*Q1D + tidz*Q1D*Q1D];//contract x direction
    }
  }
  __syncthreads();
}
225 
// 2D tensor interpolation. blockDim.z packs BASIS_NCOMP components per
// element, so tidz encodes both the component and the element slot within the
// block. Forward: nodes -> quadrature (B in x then y); transpose: quadrature
// -> nodes (B^T in y then x).
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  // Decompose tidz into (element slot, component).
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // Cleanup: removed a redundant re-declaration of `comp` here that
    // shadowed the identical computation above the loop.
    r_V = 0.0;
    r_t = 0.0;
    if(!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
    } else {
      readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
258 
// 2D tensor gradient. Forward: two passes over the same nodal input r_U,
// applying G in x (dim 0) then G in y (dim 1), one output write per dim.
// Transpose: the two dim contributions are transpose-contracted back to
// nodes and summed (r_V += r_U) before a single nodal write.
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  // tidz packs (element slot, component); see interp2d.
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if(!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // dim 0: d/dx = G in x, B in y
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // dim 1: d/dy = B in x, G in y (r_U is untouched by the contractions)
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 1;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      dim = 0;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 1;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      // Sum the two dim contributions before the single nodal write.
      r_V+=r_U;
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
301 //////////
302 //  3D  //
303 //////////
304 
// Load a z-line of 3D nodal values into registers: r_U[k] holds node k of the
// (tidx, tidy) column for threads at valid node positions, 0 otherwise.
// Entries [P1D, Q1D) are zero-padded for the later Q1D-length contractions.
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U, CeedScalar *r_U) {
  const bool isNode = (tidx < P1D) && (tidy < P1D);
  const int base = tidx + tidy*P1D + comp*P1D*P1D*P1D +
                   elem*BASIS_NCOMP*P1D*P1D*P1D;
  for (int k = 0; k < P1D; ++k) {
    r_U[k] = isNode ? d_U[base + k*P1D*P1D] : 0.0;
  }
  for (int k = P1D; k < Q1D; ++k) {
    r_U[k] = 0.0;
  }
}
315 
// Load a z-line of 3D quadrature values for (elem, comp, dim) into registers.
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem, const CeedScalar *d_U, CeedScalar *r_U) {
  const int base = tidx + tidy*Q1D + elem*Q1D*Q1D*Q1D +
                   comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
  for (int k = 0; k < Q1D; ++k) {
    r_U[k] = d_U[base + k*Q1D*Q1D];
  }
}
323 
// Store a z-line of 3D nodal values; only threads at valid node positions
// (tidx, tidy both < P1D) write their P1D entries.
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V, CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D) return;
  const int base = tidx + tidy*P1D + comp*P1D*P1D*P1D +
                   elem*BASIS_NCOMP*P1D*P1D*P1D;
  for (int k = 0; k < P1D; ++k) {
    d_V[base + k*P1D*P1D] = r_V[k];
  }
}
333 
// Store a z-line of 3D quadrature values for (elem, comp, dim).
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem, const CeedScalar *r_V, CeedScalar *d_V) {
  const int base = tidx + tidy*Q1D + elem*Q1D*Q1D*Q1D +
                   comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
  for (int k = 0; k < Q1D; ++k) {
    d_V[base + k*Q1D*Q1D] = r_V[k];
  }
}
341 
// 3D x-contraction, one z-plane (k) at a time: stage the plane in shared
// memory, then V[k](tidx,tidy) = sum_i B(i,tidx) * slice(i,tidy).
// Barriers bracket each plane's staging (all threads reach them every k).
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    for (int i = 0; i < P1D; ++i) {
      V[k] += B[i + tidx*P1D] * slice[i + tidy*Q1D + tidz*Q1D*Q1D];//contract x direction
    }
    __syncthreads();
  }
}
355 
// 3D y-contraction, one z-plane (k) at a time: stage the plane in shared
// memory, then V[k](tidx,tidy) = sum_i B(i,tidy) * slice(tidx,i).
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    for (int i = 0; i < P1D; ++i) {
      V[k] += B[i + tidy*P1D] * slice[tidx + i*Q1D + tidz*Q1D*Q1D];//contract y direction
    }
    __syncthreads();
  }
}
369 
// 3D z-contraction entirely in registers: V[k] = sum_i B(i,k) * U[i]. Each
// thread owns a full z-line, so no shared memory or synchronization is
// needed; slice/tidx/tidy/tidz are unused (kept for a uniform signature).
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < Q1D; ++k) {
    V[k] = 0.0;
    for (int i = 0; i < P1D; ++i) {
      V[k] += B[i + k*P1D] * U[i];//contract z direction
    }
  }
}
380 
// Transposed 3D z-contraction in registers: V[k] = sum_i B(k,i) * U[i] for
// node indices k < P1D; padding entries k in [P1D, Q1D) are set to zero.
// No shared memory or synchronization; slice/tid params unused.
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < Q1D; ++k) {
    V[k] = 0.0;
    if (k<P1D) {
      for (int i = 0; i < Q1D; ++i) {
        V[k] += B[k + i*P1D] * U[i];//contract z direction
      }
    }
  }
}
393 
// Transposed 3D y-contraction, one z-plane (k) at a time:
// V[k](tidx,tidy) = sum_i B(tidy,i) * slice(tidx,i) for rows tidy < P1D,
// zero otherwise. Barriers stay outside the divergent guard.
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidy<P1D) {
      for (int i = 0; i < Q1D; ++i) {
        V[k] += B[tidy + i*P1D] * slice[tidx + i*Q1D + tidz*Q1D*Q1D];//contract y direction
      }
    }
    __syncthreads();
  }
}
409 
// Transposed 3D x-contraction, one z-plane (k) at a time:
// V[k](tidx,tidy) = sum_i B(tidx,i) * slice(i,tidy) for columns tidx < P1D,
// zero otherwise. Barriers stay outside the divergent guard.
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar *U, const CeedScalar *B, CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx+tidy*Q1D+tidz*Q1D*Q1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx<P1D) {
      for (int i = 0; i < Q1D; ++i) {
        V[k] += B[tidx + i*P1D] * slice[i + tidy*Q1D + tidz*Q1D*Q1D];//contract x direction
      }
    }
    __syncthreads();
  }
}
425 
// 3D tensor interpolation. Each thread carries a z-line in registers
// (r_V/r_t); x/y contractions stage planes through the shared-memory slice,
// the z contraction stays in registers. tidz packs (element slot, component)
// with blockDim.z == BASIS_NCOMP * elemsPerBlock.
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V[Q1D];
  CeedScalar r_t[Q1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // Clear both register lines for this element pass.
    for (int i = 0; i < Q1D; ++i) {
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if(!transpose) {
      // nodes -> quadrature: B in x, y, then z; result ends in r_t.
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
    } else {
      // quadrature -> nodes: B^T in z, y, then x; result ends in r_t.
      readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
    }
  }
}
461 
// 3D tensor gradient. Forward: three passes over the same nodal input r_U,
// placing G in the x, y, or z slot respectively and writing one output dim
// each. Transpose: each dim's quadrature data is transpose-contracted back
// to nodes and accumulated into r_V via add() before a single nodal write.
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  //use P1D for one of these
  CeedScalar r_U[Q1D];
  CeedScalar r_V[Q1D];
  CeedScalar r_t[Q1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  // tidz packs (element slot, component); see interp3d.
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if(!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // dim 0: G in x, B in y/z
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // dim 1: G in y
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_G, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 1;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // dim 2: G in z
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 2;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      // dim 0 contribution initializes r_V ...
      dim = 0;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      // ... dims 1 and 2 are accumulated into it.
      dim = 1;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_G, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      dim = 2;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
520 
521 /////////////
522 // Kernels //
523 /////////////
// Entry kernel: dispatch the interpolation to the dimension-specific
// implementation. BASIS_DIM is a compile-time macro, so the dead branches are
// eliminated. `slice` is the dynamic shared-memory scratch sized by the host.
extern "C" __global__ void interp(const CeedInt nelem, const int transpose,
                                  const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  extern __shared__ double slice[];
  switch (BASIS_DIM) {
  case 1: interp1d(nelem, transpose, c_B, d_U, d_V, slice); break;
  case 2: interp2d(nelem, transpose, c_B, d_U, d_V, slice); break;
  case 3: interp3d(nelem, transpose, c_B, d_U, d_V, slice); break;
  }
}
536 
// Entry kernel: dispatch the gradient to the dimension-specific
// implementation; see interp for the dispatch/shared-memory convention.
extern "C" __global__ void grad(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *c_G,
                                const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V) {
  extern __shared__ double slice[];
  switch (BASIS_DIM) {
  case 1: grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice); break;
  case 2: grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice); break;
  case 3: grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice); break;
  }
}
549 
550 /////////////
551 // Weights //
552 /////////////
// Fill w with the 1D quadrature weight of each point of each element.
// Thread layout: x = quadrature point, y (plus grid x) = element.
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qpt = threadIdx.x;
  const CeedScalar wq = qweight1d[qpt];
  const CeedInt stride = gridDim.x*blockDim.y;
  for (CeedInt e = blockIdx.x*blockDim.y + threadIdx.y; e < nelem; e += stride) {
    w[e*Q1D + qpt] = wq;
  }
}
563 
// Fill w with the 2D tensor-product quadrature weight of each point.
// Thread layout: (x, y) = quadrature point, z (plus grid x) = element.
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy];
  const int qoff = qx + qy*Q1D;
  const CeedInt stride = gridDim.x*blockDim.z;
  for (CeedInt e = blockIdx.x*blockDim.z + threadIdx.z; e < nelem; e += stride) {
    w[e*Q1D*Q1D + qoff] = wq;
  }
}
575 
// Fill w with the 3D tensor-product quadrature weight of each point.
// Thread layout: (x, y, z) = quadrature point, grid x = element.
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const int qz = threadIdx.z;
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy]*qweight1d[qz];
  const int qoff = qx + qy*Q1D + qz*Q1D*Q1D;
  for (int e = blockIdx.x; e < nelem; e += gridDim.x) {
    w[e*Q1D*Q1D*Q1D + qoff] = wq;
  }
}
587 
// Entry kernel: dispatch quadrature-weight computation by dimension
// (BASIS_DIM is a compile-time macro).
extern "C" __global__ void weight(const CeedInt nelem,
                                  const CeedScalar *__restrict__ qweight1d, CeedScalar *__restrict__ v) {
  switch (BASIS_DIM) {
  case 1: weight1d(nelem, qweight1d, v); break;
  case 2: weight2d(nelem, qweight1d, v); break;
  case 3: weight3d(nelem, qweight1d, v); break;
  }
}
598 
599 );
600 
601 int CeedCudaInitInterp(CeedScalar *d_B, CeedInt P1d, CeedInt Q1d,
602                        CeedScalar **c_B);
603 int CeedCudaInitInterpGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P1d,
604                            CeedInt Q1d, CeedScalar **c_B_ptr, CeedScalar **c_G_ptr);
605 
606 int CeedBasisApplyTensor_Cuda_shared(CeedBasis basis, const CeedInt nelem,
607                                      CeedTransposeMode tmode,
608                                      CeedEvalMode emode, CeedVector u, CeedVector v) {
609   int ierr;
610   Ceed ceed;
611   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
612   Ceed_Cuda_shared *ceed_Cuda;
613   CeedGetData(ceed, (void *) &ceed_Cuda); CeedChk(ierr);
614   CeedBasis_Cuda_shared *data;
615   CeedBasisGetData(basis, (void *)&data); CeedChk(ierr);
616   const CeedInt transpose = tmode == CEED_TRANSPOSE;
617   CeedInt dim, ncomp;
618   ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
619   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
620 
621   const CeedScalar *d_u;
622   CeedScalar *d_v;
623   if(emode!=CEED_EVAL_WEIGHT) {
624     ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChk(ierr);
625   }
626   ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChk(ierr);
627 
628   if (tmode == CEED_TRANSPOSE) {
629     CeedInt length;
630     ierr = CeedVectorGetLength(v, &length); CeedChk(ierr);
631     ierr = cudaMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChk(ierr);
632   }
633   if (emode == CEED_EVAL_INTERP) {
634     CeedInt P1d, Q1d;
635     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
636     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
637     ierr = CeedCudaInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B);
638     CeedChk(ierr);
639     void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B, &d_u, &d_v};
640     if (dim==1)
641     {
642       CeedInt elemsPerBlock = 32;
643       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
644       CeedInt sharedMem = elemsPerBlock*Q1d*sizeof(CeedScalar);
645       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, 1, elemsPerBlock, sharedMem,
646                             interpargs);
647       CeedChk(ierr);
648     } else if (dim==2) {
649       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
650       CeedInt elemsPerBlock = Q1d < 7 ? optElems[Q1d]/ncomp : 1;
651       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
652       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
653       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, Q1d, ncomp*elemsPerBlock, sharedMem,
654                             interpargs);
655       CeedChk(ierr);
656     } else if (dim==3) {
657       CeedInt elemsPerBlock = 1;
658       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
659       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
660       ierr = CeedRunKernelDimSharedCuda(ceed, data->interp, grid, Q1d, Q1d, ncomp*elemsPerBlock, sharedMem,
661                             interpargs);
662       CeedChk(ierr);
663     }
664   } else if (emode == CEED_EVAL_GRAD) {
665     CeedInt P1d, Q1d;
666     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
667     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
668     ierr = CeedCudaInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d,
669                                   Q1d, &data->c_B, &data->c_G);
670     CeedChk(ierr);
671     void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B, &data->c_G, &d_u, &d_v};
672     if (dim==1)
673     {
674       CeedInt elemsPerBlock = 32;
675       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
676       CeedInt sharedMem = elemsPerBlock*Q1d*sizeof(CeedScalar);
677       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, 1, elemsPerBlock, sharedMem,
678                           gradargs);
679       CeedChk(ierr);
680     } else if (dim==2) {
681       const CeedInt optElems[7] = {0,32,8,6,4,2,8};
682       CeedInt elemsPerBlock = Q1d < 7 ? optElems[Q1d]/ncomp : 1;
683       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
684       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
685       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, Q1d, ncomp*elemsPerBlock, sharedMem,
686                           gradargs);
687       CeedChk(ierr);
688     } else if (dim==3) {
689       CeedInt elemsPerBlock = 1;
690       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
691       CeedInt sharedMem = ncomp*elemsPerBlock*Q1d*Q1d*sizeof(CeedScalar);
692       ierr = CeedRunKernelDimSharedCuda(ceed, data->grad, grid, Q1d, Q1d, ncomp*elemsPerBlock, sharedMem,
693                           gradargs);
694       CeedChk(ierr);
695     }
696   } else if (emode == CEED_EVAL_WEIGHT) {
697     CeedInt Q1d;
698     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
699     void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
700     if(dim==1){
701       const CeedInt elemsPerBlock = 32/Q1d;
702       const CeedInt gridsize = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
703       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, elemsPerBlock, 1, weightargs);
704       CeedChk(ierr);
705     } else if(dim==2) {
706       const CeedInt optElems = 32/(Q1d*Q1d);
707       const CeedInt elemsPerBlock = optElems>0?optElems:1;
708       const CeedInt gridsize = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
709       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d, elemsPerBlock, weightargs);
710       CeedChk(ierr);
711     } else if(dim==3) {
712       const CeedInt gridsize = nelem;
713       ierr = CeedRunKernelDimCuda(ceed, data->weight, gridsize, Q1d, Q1d, Q1d, weightargs);
714       CeedChk(ierr);
715     }
716   }
717 
718   if(emode!=CEED_EVAL_WEIGHT) {
719     ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChk(ierr);
720   }
721   ierr = CeedVectorRestoreArray(v, &d_v); CeedChk(ierr);
722 
723   return 0;
724 }
725 
// Destroy the cuda-shared basis backend data: unload the JIT-compiled CUDA
// module, free the device copies of the 1D quadrature weights and the interp
// and grad matrices, then release the backend data struct itself.
static int CeedBasisDestroy_Cuda_shared(CeedBasis basis) {
  int ierr;
  Ceed ceed;
  ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);

  CeedBasis_Cuda_shared *data;
  ierr = CeedBasisGetData(basis, (void *) &data); CeedChk(ierr);

  // cuModuleUnload returns a CUresult; check it with the CUDA checker.
  CeedChk_Cu(ceed, cuModuleUnload(data->module));

  ierr = cudaFree(data->d_qweight1d); CeedChk_Cu(ceed, ierr);
  ierr = cudaFree(data->d_interp1d); CeedChk_Cu(ceed, ierr);
  ierr = cudaFree(data->d_grad1d); CeedChk_Cu(ceed, ierr);

  ierr = CeedFree(&data); CeedChk(ierr);

  return 0;
}
744 
// Create the cuda-shared backend data for a tensor-product H1 basis: upload
// the 1D quadrature weights, interp, and grad matrices to the device, JIT
// compile the shared-memory kernel source with this basis's compile-time
// sizes, and register the backend Apply/Destroy functions.
// Requires Q1d >= P1d (the kernels zero-pad nodes up to Q1D and size shared
// memory by Q1D).
int CeedBasisCreateTensorH1_Cuda_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
                                        const CeedScalar *interp1d,
                                        const CeedScalar *grad1d,
                                        const CeedScalar *qref1d,
                                        const CeedScalar *qweight1d,
                                        CeedBasis basis) {
  int ierr;
  Ceed ceed;
  ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
  if (Q1d<P1d)
  {
    return CeedError(ceed, 1, "Backend does not implement underintegrated basis.");
  }
  CeedBasis_Cuda_shared *data;
  ierr = CeedCalloc(1, &data); CeedChk(ierr);

  // Device copy of the 1D quadrature weights (Q1d scalars).
  const CeedInt qBytes = Q1d * sizeof(CeedScalar);
  ierr = cudaMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Cu(ceed, ierr);
  ierr = cudaMemcpy(data->d_qweight1d, qweight1d, qBytes,
                    cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);

  // Device copies of the Q1d x P1d interp and grad matrices.
  const CeedInt iBytes = qBytes * P1d;
  ierr = cudaMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Cu(ceed, ierr);
  ierr = cudaMemcpy(data->d_interp1d, interp1d, iBytes,
                    cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);

  ierr = cudaMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Cu(ceed, ierr);
  ierr = cudaMemcpy(data->d_grad1d, grad1d, iBytes,
                    cudaMemcpyHostToDevice); CeedChk_Cu(ceed, ierr);

  // JIT compile the kernel string with this basis's compile-time constants
  // (Q1D/P1D/BASIS_* macros referenced throughout kernelsShared).
  CeedInt ncomp;
  ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
  ierr = CeedCompileCuda(ceed, kernelsShared, &data->module, 7,
                         "Q1D", Q1d,
                         "P1D", P1d,
                         "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
                             Q1d : P1d, dim),
                         "BASIS_DIM", dim,
                         "BASIS_NCOMP", ncomp,
                         "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
                         "BASIS_NQPT", CeedIntPow(Q1d, dim)
                        ); CeedChk(ierr);
  ierr = CeedGetKernelCuda(ceed, data->module, "interp", &data->interp);
  CeedChk(ierr);
  ierr = CeedGetKernelCuda(ceed, data->module, "grad", &data->grad);
  CeedChk(ierr);
  ierr = CeedGetKernelCuda(ceed, data->module, "weight", &data->weight);
  CeedChk(ierr);

  // Attach the backend data and register the backend entry points.
  ierr = CeedBasisSetData(basis, (void *)&data);
  CeedChk(ierr);
  ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
                                CeedBasisApplyTensor_Cuda_shared);
  CeedChk(ierr);
  ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
                                CeedBasisDestroy_Cuda_shared);
  CeedChk(ierr);
  return 0;
}
804