xref: /libCEED/backends/hip-shared/ceed-hip-shared-basis.c (revision 7be1e82bee12be3372d0edb3771f3d66d6ac97b8)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include "ceed-hip-shared.h"
18 #include "../hip/ceed-hip-compile.h"
19 
20 //------------------------------------------------------------------------------
21 // Shared mem kernels
22 //------------------------------------------------------------------------------
23 // *INDENT-OFF*
24 static const char *kernelsShared = QUOTE(
25 
26 //------------------------------------------------------------------------------
27 // Sum input into output
28 //------------------------------------------------------------------------------
// Accumulate the P1D entries of r_U into r_V, component-wise
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  for (int j = 0; j < P1D; j++) {
    r_V[j] = r_V[j] + r_U[j];
  }
}
33 
34 //------------------------------------------------------------------------------
35 // 1D
36 //------------------------------------------------------------------------------
37 
38 //------------------------------------------------------------------------------
39 // Read DoFs
40 //------------------------------------------------------------------------------
// Load one element's P1D nodal values for component `comp` into the shared
// slice row owned by this tidz layer; pad entries [P1D, Q1D) with zero.
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int tidz,const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *slice) {
  const int row = tidz*T1D;
  const int offset = elem*P1D + comp*P1D*nelem;
  for (int node = 0; node < P1D; node++)
    slice[row + node] = d_U[offset + node];
  for (int node = P1D; node < Q1D; node++)
    slice[row + node] = 0.0;
}
50 
51 //------------------------------------------------------------------------------
52 // Write DoFs
53 //------------------------------------------------------------------------------
// Store this thread's nodal value; only threads with tidx < P1D own a node.
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D) return;
  d_V[tidx + elem*P1D + comp*P1D*nelem] = r_V;
}
61 
62 //------------------------------------------------------------------------------
63 // Read quadrature point data
64 //------------------------------------------------------------------------------
// Load one element's Q1D quadrature values (component `comp`, derivative
// index `dim`) into the shared slice row for tidz; pad [Q1D, P1D) with zero.
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int tidz, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *slice) {
  const int row = tidz*T1D;
  const int offset = elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  for (int q = 0; q < Q1D; q++)
    slice[row + q] = d_U[offset + q];
  for (int q = Q1D; q < P1D; q++)
    slice[row + q] = 0.0;
}
75 
76 //------------------------------------------------------------------------------
77 // Write quadrature point data
78 //------------------------------------------------------------------------------
// Store this thread's quadrature-point value; only threads with
// tidx < Q1D own a quadrature point.
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx >= Q1D) return;
  d_V[tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D] = r_V;
}
86 
87 //------------------------------------------------------------------------------
88 // 1D tensor contraction
89 //------------------------------------------------------------------------------
// 1D x-contraction: V = sum_p B[p, tidx] * slice[p] over the tidz-owned row.
// NOTE: the U argument is not read; input comes from the shared slice.
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  CeedScalar sum = 0.0;
  for (int p = 0; p < P1D; ++p)
    sum += B[p + tidx*P1D] * slice[p + tidz*T1D]; // Contract x direction
  V = sum;
}
98 
99 //------------------------------------------------------------------------------
100 // 1D transpose tensor contraction
101 //------------------------------------------------------------------------------
// 1D transpose x-contraction: V = sum_q B[tidx, q] * slice[q].
// NOTE: the U argument is not read; input comes from the shared slice.
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  CeedScalar sum = 0.0;
  for (int q = 0; q < Q1D; ++q)
    sum += B[tidx + q*P1D] * slice[q + tidz*T1D]; // Contract x direction
  V = sum;
}
109 
110 //------------------------------------------------------------------------------
111 // 1D interpolate to quadrature points
112 //------------------------------------------------------------------------------
// Apply the 1D interpolation basis for all elements and components.
// transpose == 0 maps nodes -> quadrature points via c_B; otherwise the
// transpose action maps quadrature points -> nodes. Each z-layer of the
// block owns one element; elements advance grid-stride by blockDim.z.
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;


  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for (int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) {
        // Stage this component's nodes in shared memory, then contract
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        // r_t is forwarded as the (unused) U argument of the contraction
        ContractX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      } else {
        readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
141 
142 //------------------------------------------------------------------------------
143 // 1D derivatives at quadrature points
144 //------------------------------------------------------------------------------
// Apply the 1D gradient basis (c_G) to all elements and components.
// transpose == 0: nodes -> derivative values at quadrature points;
// otherwise the transpose action. Thread/element layout matches interp1d.
// NOTE: c_B is unused in the 1D case (only one direction, handled by c_G).
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  int dim;

  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) {
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        // r_U is forwarded as the (unused) U argument of the contraction
        ContractX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        dim = 0;
        writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      } else {
        dim = 0;
        readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
175 
176 //------------------------------------------------------------------------------
177 // 1D Quadrature weights
178 //------------------------------------------------------------------------------
// Write the 1D quadrature weight owned by threadIdx.x for every element,
// striding over elements with the (gridDim.x * blockDim.y) layout.
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int tid = threadIdx.x;
  const CeedScalar wq = qweight1d[tid];
  const CeedInt stride = gridDim.x*blockDim.y;
  for (CeedInt elem = blockIdx.x*blockDim.y + threadIdx.y; elem < nelem;
       elem += stride) {
    w[elem*Q1D + tid] = wq;
  }
}
189 
190 //------------------------------------------------------------------------------
191 // 2D
192 //------------------------------------------------------------------------------
193 
194 //------------------------------------------------------------------------------
195 // Read DoFs
196 //------------------------------------------------------------------------------
// Load this thread's nodal value, or zero when outside the P1D x P1D grid.
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar &U) {
  if (tidx<P1D && tidy<P1D)
    U = d_U[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem];
  else
    U = 0.0;
}
204 
205 //------------------------------------------------------------------------------
206 // Write DoFs
207 //------------------------------------------------------------------------------
// Store this thread's nodal value; only the P1D x P1D thread corner writes.
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D) return;
  d_V[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] = r_V;
}
215 
216 //------------------------------------------------------------------------------
217 // Read quadrature point data
218 //------------------------------------------------------------------------------
// Load this thread's quadrature value for (comp, dim), or zero when outside
// the Q1D x Q1D grid.
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar &U ) {
  if (tidx<Q1D && tidy<Q1D)
    U = d_U[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
            dim*BASIS_NCOMP*nelem*Q1D*Q1D];
  else
    U = 0.0;
}
227 
228 //------------------------------------------------------------------------------
229 // Write quadrature point data
230 //------------------------------------------------------------------------------
// Store this thread's quadrature value for (comp, dim); only the Q1D x Q1D
// thread corner writes.
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx >= Q1D || tidy >= Q1D) return;
  d_V[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
      dim*BASIS_NCOMP*nelem*Q1D*Q1D] = r_V;
}
239 
240 //------------------------------------------------------------------------------
241 // 2D tensor contraction x
242 //------------------------------------------------------------------------------
// 2D x-contraction: stage U in the shared slice, barrier, then threads with
// tidx < Q1D reduce their row against B. Both barriers are reached by all
// threads in the block (reduction is guarded, syncs are not).
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  CeedScalar sum = 0.0;
  if (tidx < Q1D) {
    for (int p = 0; p < P1D; ++p)
      sum += B[p + tidx*P1D] * slice[p + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
  }
  V = sum;
  __syncthreads();
}
255 
256 //------------------------------------------------------------------------------
257 // 2D tensor contraction y
258 //------------------------------------------------------------------------------
// 2D y-contraction: stage U in the shared slice, barrier, then threads with
// tidy < Q1D reduce their column against B.
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  CeedScalar sum = 0.0;
  if (tidy < Q1D) {
    for (int p = 0; p < P1D; ++p)
      sum += B[p + tidy*P1D] * slice[tidx + p*T1D + tidz*T1D*T1D]; // Contract y direction
  }
  V = sum;
  __syncthreads();
}
271 
272 //------------------------------------------------------------------------------
273 // 2D transpose tensor contraction y
274 //------------------------------------------------------------------------------
// 2D transpose y-contraction: stage U in the shared slice, barrier, then
// threads with tidy < P1D reduce against the transposed B.
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  CeedScalar sum = 0.0;
  if (tidy < P1D) {
    for (int q = 0; q < Q1D; ++q)
      sum += B[tidy + q*P1D] * slice[tidx + q*T1D + tidz*T1D*T1D]; // Contract y direction
  }
  V = sum;
  __syncthreads();
}
286 
287 //------------------------------------------------------------------------------
288 // 2D transpose tensor contraction x
289 //------------------------------------------------------------------------------
// 2D transpose x-contraction: stage U in the shared slice, barrier, then
// threads with tidx < P1D reduce against the transposed B.
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  CeedScalar sum = 0.0;
  if (tidx < P1D) {
    for (int q = 0; q < Q1D; ++q)
      sum += B[tidx + q*P1D] * slice[q + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
  }
  V = sum;
  __syncthreads();
}
301 
302 //------------------------------------------------------------------------------
303 // 2D interpolate to quadrature points
304 //------------------------------------------------------------------------------
// Apply the 2D interpolation basis for all elements and components.
// transpose == 0: nodes -> quadrature points via c_B in x then y; otherwise
// the transpose contractions in reverse order. Each z-layer of the block
// handles one (element, component) pair; elements advance grid-stride by
// elemsPerBlock.
// Fix: the loop body previously re-declared `const int comp` with the same
// initializer as the loop-invariant declaration above, shadowing it on every
// iteration; the redundant inner declaration is removed.
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  // Each z-layer handles one (element, component) pair
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    r_V = 0.0;
    r_t = 0.0;
    if (!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
    } else {
      readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
338 
339 //------------------------------------------------------------------------------
340 // 2D derivatives at quadrature points
341 //------------------------------------------------------------------------------
// Apply the 2D gradient basis to all elements and components.
// Forward (transpose == 0): for each direction dim, contract c_G in that
// direction and c_B in the other, writing one quadrature field per dim.
// Transpose: read both derivative fields, apply the transpose contractions,
// and sum the two contributions into a single nodal field.
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V, CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  // Each z-layer handles one (element, component) pair
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if (!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // d/dx: G in x, B in y
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // d/dy: B in x, G in y
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 1;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      dim = 0;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 1;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      // Sum the x- and y-derivative contributions
      r_V += r_U;
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
384 
385 //------------------------------------------------------------------------------
386 // 2D quadrature weights
387 //------------------------------------------------------------------------------
// Write the tensor-product 2D quadrature weight for this (x, y) thread to
// every element, striding over elements by (gridDim.x * blockDim.z).
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy];
  const CeedInt stride = gridDim.x*blockDim.z;
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += stride) {
    w[elem*Q1D*Q1D + qx + qy*Q1D] = wq;
  }
}
399 
400 //------------------------------------------------------------------------------
401 // 3D
402 //------------------------------------------------------------------------------
403 
404 //------------------------------------------------------------------------------
405 // Read DoFs
406 //------------------------------------------------------------------------------
// Load this thread's z-column of nodal values into registers; threads
// outside the P1D x P1D corner and entries [P1D, Q1D) are zero-filled.
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *r_U) {
  const bool active = tidx < P1D && tidy < P1D;
  for (int k = 0; k < P1D; k++) {
    if (active)
      r_U[k] = d_U[tidx + tidy*P1D + k*P1D*P1D + elem*P1D*P1D*P1D +
                   comp*P1D*P1D*P1D*nelem];
    else
      r_U[k] = 0.0;
  }
  for (int k = P1D; k < Q1D; k++)
    r_U[k] = 0.0;
}
418 
419 //------------------------------------------------------------------------------
420 // Write DoFs
421 //------------------------------------------------------------------------------
// Store this thread's z-column of nodal values; only the P1D x P1D corner
// of threads writes.
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D) return;
  for (int k = 0; k < P1D; k++)
    d_V[tidx + tidy*P1D + k*P1D*P1D + elem*P1D*P1D*P1D +
        comp*P1D*P1D*P1D*nelem] = r_V[k];
}
432 
433 //------------------------------------------------------------------------------
434 // Read quadrature point data
435 //------------------------------------------------------------------------------
// Load this thread's z-column of quadrature values for (comp, dim); threads
// outside the Q1D x Q1D corner and entries [Q1D, P1D) are zero-filled.
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *r_U) {
  const bool active = tidx < Q1D && tidy < Q1D;
  for (int k = 0; k < Q1D; k++) {
    if (active)
      r_U[k] = d_U[tidx + tidy*Q1D + k*Q1D*Q1D + elem*Q1D*Q1D*Q1D +
                   comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D];
    else
      r_U[k] = 0.0;
  }
  for (int k = Q1D; k < P1D; k++)
    r_U[k] = 0.0;
}
447 
448 //------------------------------------------------------------------------------
449 // Write quadrature point data
450 //------------------------------------------------------------------------------
// Store this thread's z-column of quadrature values for (comp, dim); only
// the Q1D x Q1D corner of threads writes.
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar *r_V, CeedScalar *d_V) {
  if (tidx >= Q1D || tidy >= Q1D) return;
  for (int k = 0; k < Q1D; k++)
    d_V[tidx + tidy*Q1D + k*Q1D*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
        dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] = r_V[k];
}
461 
462 //------------------------------------------------------------------------------
463 // 3D tensor contract x
464 //------------------------------------------------------------------------------
// 3D x-contraction, one z-plane at a time: stage U[p] in the shared slice,
// barrier, then threads in the Q1D x P1D corner reduce their row against B.
// Barriers sit outside the guarded reduction so all threads reach them.
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  for (int p = 0; p < P1D; ++p) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[p];
    __syncthreads();
    CeedScalar sum = 0.0;
    if (tidx < Q1D && tidy < P1D) {
      for (int a = 0; a < P1D; ++a)
        sum += B[a + tidx*P1D] * slice[a + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
    }
    V[p] = sum;
    __syncthreads();
  }
}
480 
481 //------------------------------------------------------------------------------
482 // 3D tensor contract y
483 //------------------------------------------------------------------------------
// 3D y-contraction, one z-plane at a time: stage U[p] in the shared slice,
// barrier, then threads in the Q1D x Q1D corner reduce their column.
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  for (int p = 0; p < P1D; ++p) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[p];
    __syncthreads();
    CeedScalar sum = 0.0;
    if (tidx < Q1D && tidy < Q1D) {
      for (int a = 0; a < P1D; ++a)
        sum += B[a + tidy*P1D] * slice[tidx + a*T1D + tidz*T1D*T1D]; // Contract y direction
    }
    V[p] = sum;
    __syncthreads();
  }
}
499 
500 //------------------------------------------------------------------------------
501 // 3D tensor contract z
502 //------------------------------------------------------------------------------
// 3D z-contraction: each thread already holds its full z-column in
// registers, so no shared memory or barriers are used; entries [Q1D, P1D)
// of V are zeroed. The slice argument is unused here.
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  for (int q = 0; q < Q1D; ++q) {
    CeedScalar sum = 0.0;
    if (tidx < Q1D && tidy < Q1D) {
      for (int p = 0; p < P1D; ++p)
        sum += B[p + q*P1D] * U[p]; // Contract z direction
    }
    V[q] = sum;
  }
  for (int q = Q1D; q < P1D; ++q)
    V[q] = 0.0;
}
517 
518 //------------------------------------------------------------------------------
519 // 3D transpose tensor contract z
520 //------------------------------------------------------------------------------
// 3D transpose z-contraction: register-only like ContractZ3d; entries
// [P1D, Q1D) of V are zeroed. The slice argument is unused here.
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  for (int p = 0; p < P1D; ++p) {
    CeedScalar sum = 0.0;
    if (tidx < Q1D && tidy < Q1D) {
      for (int q = 0; q < Q1D; ++q)
        sum += B[p + q*P1D] * U[q]; // Contract z direction
    }
    V[p] = sum;
  }
  for (int p = P1D; p < Q1D; ++p)
    V[p] = 0.0;
}
535 
536 //------------------------------------------------------------------------------
537 // 3D transpose tensor contract y
538 //------------------------------------------------------------------------------
// 3D transpose y-contraction, one z-plane at a time: stage U[p] in the
// shared slice, barrier, then threads in the Q1D x P1D corner reduce
// against the transposed B.
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  for (int p = 0; p < P1D; ++p) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[p];
    __syncthreads();
    CeedScalar sum = 0.0;
    if (tidx < Q1D && tidy < P1D) {
      for (int q = 0; q < Q1D; ++q)
        sum += B[tidy + q*P1D] * slice[tidx + q*T1D + tidz*T1D*T1D]; // Contract y direction
    }
    V[p] = sum;
    __syncthreads();
  }
}
554 
555 //------------------------------------------------------------------------------
556 // 3D transpose tensor contract x
557 //------------------------------------------------------------------------------
// 3D transpose x-contraction, one z-plane at a time: stage U[p] in the
// shared slice, barrier, then threads in the P1D x P1D corner reduce
// against the transposed B.
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  for (int p = 0; p < P1D; ++p) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[p];
    __syncthreads();
    CeedScalar sum = 0.0;
    if (tidx < P1D && tidy < P1D) {
      for (int q = 0; q < Q1D; ++q)
        sum += B[tidx + q*P1D] * slice[q + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
    }
    V[p] = sum;
    __syncthreads();
  }
}
573 
574 //------------------------------------------------------------------------------
575 // 3D interpolate to quadrature points
576 //------------------------------------------------------------------------------
// Apply the 3D interpolation basis to all elements and components.
// transpose == 0: nodes -> quadrature points by contracting c_B in x, y,
// then z; otherwise the transpose contractions in reverse order. r_V and
// r_t are ping-ponged as source/destination register columns.
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V[T1D];
  CeedScalar r_t[T1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  // Each z-layer handles one (element, component) pair
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // Clear register columns before processing this element
    for (int i = 0; i < T1D; ++i) {
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
    } else {
      readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
    }
  }
}
613 
614 //------------------------------------------------------------------------------
615 // 3D derivatives at quadrature points
616 //------------------------------------------------------------------------------
// Apply the 3D gradient basis to all elements and components.
// Forward (transpose == 0): for dim = 0,1,2 contract c_G in that direction
// and c_B in the other two, writing one quadrature field per dim.
// Transpose: apply the transpose contractions to each derivative field and
// accumulate the three contributions into r_V before the nodal write.
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // Use P1D for one of these
  CeedScalar r_U[T1D];
  CeedScalar r_V[T1D];
  CeedScalar r_t[T1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  // Each z-layer handles one (element, component) pair
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // Clear register columns before processing this element
    for (int i = 0; i < T1D; ++i) {
      r_U[i] = 0.0;
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // d/dx: G in x, B in y and z
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // d/dy: G in y, B in x and z
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_G, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 1;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // d/dz: G in z, B in x and y
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 2;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      // x-derivative contribution -> r_V
      dim = 0;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      // y-derivative contribution accumulated into r_V
      dim = 1;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_G, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      // z-derivative contribution accumulated into r_V
      dim = 2;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
681 
682 //------------------------------------------------------------------------------
683 // 3D quadrature weights
684 //------------------------------------------------------------------------------
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  // One thread per quadrature point of the Q1D^3 tensor grid
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const int qz = threadIdx.z;
  // Tensor-product quadrature weight for this point; invariant over elements
  const CeedScalar wqpt = qweight1d[qx]*qweight1d[qy]*qweight1d[qz];
  // Stride over elements by the number of launched blocks
  for (int elem = blockIdx.x; elem < nelem; elem += gridDim.x)
    w[elem*Q1D*Q1D*Q1D + qz*Q1D*Q1D + qy*Q1D + qx] = wqpt;
}
696 
697 
698 //------------------------------------------------------------------------------
699 // Basis kernels
700 //------------------------------------------------------------------------------
701 
702 //------------------------------------------------------------------------------
703 // Interp kernel by dim
704 //------------------------------------------------------------------------------
extern "C" __launch_bounds__(INTERP_BLKSIZE) __global__ void interp(
                                  const CeedInt nelem, const int transpose,
                                  const CeedScalar *c_B,
                                  const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  HIP_DYNAMIC_SHARED( double, slice)
  // BASIS_DIM is a compile-time constant, so only one case is ever emitted
  switch (BASIS_DIM) {
  case 1: interp1d(nelem, transpose, c_B, d_U, d_V, slice); break;
  case 2: interp2d(nelem, transpose, c_B, d_U, d_V, slice); break;
  case 3: interp3d(nelem, transpose, c_B, d_U, d_V, slice); break;
  }
}
719 
720 //------------------------------------------------------------------------------
721 // Grad kernel by dim
722 //------------------------------------------------------------------------------
extern "C" __launch_bounds__(GRAD_BLKSIZE) __global__ void grad(const CeedInt nelem,
                                const int transpose,
                                const CeedScalar *c_B, const CeedScalar *c_G,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V) {
  HIP_DYNAMIC_SHARED( double, slice)
  // BASIS_DIM is a compile-time constant, so only one case is ever emitted
  switch (BASIS_DIM) {
  case 1: grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice); break;
  case 2: grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice); break;
  case 3: grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice); break;
  }
}
737 
738 //------------------------------------------------------------------------------
739 // Weight kernels by dim
740 //------------------------------------------------------------------------------
extern "C" __launch_bounds__(WEIGHT_BLKSIZE) __global__ void weight(const CeedInt nelem,
                                  const CeedScalar *__restrict__ qweight1d,
                                  CeedScalar *__restrict__ v) {
  // BASIS_DIM is a compile-time constant, so only one case is ever emitted
  switch (BASIS_DIM) {
  case 1: weight1d(nelem, qweight1d, v); break;
  case 2: weight2d(nelem, qweight1d, v); break;
  case 3: weight3d(nelem, qweight1d, v); break;
  }
}
752 
753 );
754 // *INDENT-ON*
755 
756 //------------------------------------------------------------------------------
757 // Compute a block size based on required minimum threads
758 //------------------------------------------------------------------------------
759 static CeedInt ComputeBlockSizeFromRequirement(const CeedInt required) {
760   CeedInt maxSize = 1024;    // Max total threads per block
761   CeedInt currentSize = 64;  // Start with one group
762 
763   while(currentSize < maxSize) {
764     if (currentSize > required)
765       break;
766     else
767       currentSize = currentSize * 2;
768   }
769   return currentSize;
770 }
771 
772 //------------------------------------------------------------------------------
773 // Compute required thread block sizes for basis kernels given P, Q, dim, and
774 // ncomp
775 //------------------------------------------------------------------------------
776 static int ComputeBasisThreadBlockSizes(const CeedInt dim, const CeedInt P1d,
777                                         const CeedInt Q1d,
778                                         const CeedInt ncomp, CeedInt *blksizes) {
779 
780   // Note that this will use the same block sizes for all dimensions when compiling,
781   // but as each basis object is defined for a particular dimension, we will never
782   // call any kernels except the ones for the dimension for which we have computed the
783   // block sizes.
784   const CeedInt thread1d = CeedIntMax(P1d, Q1d);
785   switch (dim) {
786   case 1: {
787     // Interp kernels:
788     blksizes[0] = 256;
789 
790     // Grad kernels:
791     blksizes[1] = 256;
792 
793     // Weight kernels:
794     blksizes[2] = 256;
795 
796   } break;
797   case 2: {
798     // Interp kernels:
799     CeedInt required = thread1d * thread1d * ncomp;
800     blksizes[0]  = ComputeBlockSizeFromRequirement(required);
801 
802     // Grad kernels: currently use same required minimum threads
803     blksizes[1]  = ComputeBlockSizeFromRequirement(required);
804 
805     // Weight kernels:
806     required = CeedIntMax(64, Q1d * Q1d);
807     blksizes[2]  = ComputeBlockSizeFromRequirement(required);
808 
809   } break;
810   case 3: {
811     // Interp kernels:
812     CeedInt required = thread1d * thread1d * ncomp;
813     blksizes[0]  = ComputeBlockSizeFromRequirement(required);
814 
815     // Grad kernels: currently use same required minimum threads
816     blksizes[1]  = ComputeBlockSizeFromRequirement(required);
817 
818     // Weight kernels:
819     required = Q1d * Q1d * Q1d;
820     blksizes[2]  = ComputeBlockSizeFromRequirement(required);
821   }
822   }
823 
824   return 0;
825 }
826 
827 //------------------------------------------------------------------------------
828 // Apply basis
829 //------------------------------------------------------------------------------
830 int CeedBasisApplyTensor_Hip_shared(CeedBasis basis, const CeedInt nelem,
831                                     CeedTransposeMode tmode,
832                                     CeedEvalMode emode, CeedVector u,
833                                     CeedVector v) {
834   int ierr;
835   Ceed ceed;
836   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
837   Ceed_Hip_shared *ceed_Hip;
838   CeedGetData(ceed, &ceed_Hip); CeedChk(ierr);
839   CeedBasis_Hip_shared *data;
840   CeedBasisGetData(basis, &data); CeedChk(ierr);
841   const CeedInt transpose = tmode == CEED_TRANSPOSE;
842   CeedInt dim, ncomp;
843   ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
844   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
845 
846   // Read vectors
847   const CeedScalar *d_u;
848   CeedScalar *d_v;
849   if (emode != CEED_EVAL_WEIGHT) {
850     ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChk(ierr);
851   }
852   ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChk(ierr);
853 
854   // Clear v for transpose mode
855   if (tmode == CEED_TRANSPOSE) {
856     CeedInt length;
857     ierr = CeedVectorGetLength(v, &length); CeedChk(ierr);
858     ierr = hipMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChk(ierr);
859   }
860 
861   // Apply basis operation
862   switch (emode) {
863   case CEED_EVAL_INTERP: {
864     CeedInt P1d, Q1d;
865     CeedInt blksize = data->blksizes[0];
866     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
867     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
868     CeedInt thread1d = CeedIntMax(Q1d, P1d);
869     ierr = CeedHipInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B);
870     CeedChk(ierr);
871     void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
872                           &d_u, &d_v
873                          };
874     if (dim == 1) {
875       CeedInt elemsPerBlock = 64*thread1d > 256? 256/thread1d : 64;
876       elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1;
877       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
878                                              ? 1 : 0 );
879       CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
880       ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, 1,
881                                        elemsPerBlock, sharedMem,
882                                        interpargs); CeedChk(ierr);
883     } else if (dim == 2) {
884       // Check if required threads is small enough to do multiple elems
885       const CeedInt elemsPerBlock = CeedIntMax(blksize/(thread1d*thread1d*ncomp), 1);
886       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
887                                              ? 1 : 0 );
888       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
889       ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d,
890                                        ncomp*elemsPerBlock, sharedMem,
891                                        interpargs); CeedChk(ierr);
892     } else if (dim == 3) {
893       CeedInt elemsPerBlock = 1;
894       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
895                                              ? 1 : 0 );
896       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
897       ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d,
898                                        ncomp*elemsPerBlock, sharedMem,
899                                        interpargs); CeedChk(ierr);
900     }
901   } break;
902   case CEED_EVAL_GRAD: {
903     CeedInt P1d, Q1d;
904     CeedInt blksize = data->blksizes[1];
905     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
906     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
907     CeedInt thread1d = CeedIntMax(Q1d, P1d);
908     ierr = CeedHipInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d,
909                                  Q1d, &data->c_B, &data->c_G);
910     CeedChk(ierr);
911     void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
912                         &data->c_G, &d_u, &d_v
913                        };
914     if (dim == 1) {
915       CeedInt elemsPerBlock = 64*thread1d > 256? 256/thread1d : 64;
916       elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1;
917       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
918                                              ? 1 : 0 );
919       CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
920       ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, 1,
921                                        elemsPerBlock, sharedMem, gradargs);
922       CeedChk(ierr);
923     } else if (dim == 2) {
924       // Check if required threads is small enough to do multiple elems
925       const CeedInt elemsPerBlock = CeedIntMax(blksize/(thread1d*thread1d*ncomp), 1);
926       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
927                                              ? 1 : 0 );
928       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
929       ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d,
930                                        ncomp*elemsPerBlock, sharedMem,
931                                        gradargs); CeedChk(ierr);
932     } else if (dim == 3) {
933       CeedInt elemsPerBlock = 1;
934       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
935                                              ? 1 : 0 );
936       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
937       ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d,
938                                        ncomp*elemsPerBlock, sharedMem,
939                                        gradargs); CeedChk(ierr);
940     }
941   } break;
942   case CEED_EVAL_WEIGHT: {
943     CeedInt Q1d;
944     CeedInt blksize = data->blksizes[2];
945     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
946     void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
947     if (dim == 1) {
948       const CeedInt optElems = blksize/Q1d;
949       const CeedInt elemsPerBlock = optElems>0?optElems:1;
950       const CeedInt gridsize = nelem/elemsPerBlock + ( (
951                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
952       ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d,
953                                  elemsPerBlock, 1, weightargs);
954       CeedChk(ierr);
955     } else if (dim == 2) {
956       const CeedInt optElems = blksize/(Q1d*Q1d);
957       const CeedInt elemsPerBlock = optElems>0?optElems:1;
958       const CeedInt gridsize = nelem/elemsPerBlock + ( (
959                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
960       ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d,
961                                  elemsPerBlock, weightargs);
962       CeedChk(ierr);
963     } else if (dim == 3) {
964       const CeedInt gridsize = nelem;
965       ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d, Q1d,
966                                  weightargs);
967       CeedChk(ierr);
968     }
969   } break;
970   // LCOV_EXCL_START
971   // Evaluate the divergence to/from the quadrature points
972   case CEED_EVAL_DIV:
973     return CeedError(ceed, 1, "CEED_EVAL_DIV not supported");
974   // Evaluate the curl to/from the quadrature points
975   case CEED_EVAL_CURL:
976     return CeedError(ceed, 1, "CEED_EVAL_CURL not supported");
977   // Take no action, BasisApply should not have been called
978   case CEED_EVAL_NONE:
979     return CeedError(ceed, 1,
980                      "CEED_EVAL_NONE does not make sense in this context");
981     // LCOV_EXCL_STOP
982   }
983 
984   // Restore vectors
985   if (emode != CEED_EVAL_WEIGHT) {
986     ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChk(ierr);
987   }
988   ierr = CeedVectorRestoreArray(v, &d_v); CeedChk(ierr);
989   return 0;
990 }
991 
992 //------------------------------------------------------------------------------
993 // Destroy basis
994 //------------------------------------------------------------------------------
995 static int CeedBasisDestroy_Hip_shared(CeedBasis basis) {
996   int ierr;
997   Ceed ceed;
998   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
999 
1000   CeedBasis_Hip_shared *data;
1001   ierr = CeedBasisGetData(basis, &data); CeedChk(ierr);
1002 
1003   CeedChk_Hip(ceed, hipModuleUnload(data->module));
1004 
1005   ierr = hipFree(data->d_qweight1d); CeedChk_Hip(ceed, ierr);
1006   ierr = hipFree(data->d_interp1d); CeedChk_Hip(ceed, ierr);
1007   ierr = hipFree(data->d_grad1d); CeedChk_Hip(ceed, ierr);
1008   ierr = hipFree(data->d_collograd1d); CeedChk_Hip(ceed, ierr);
1009 
1010   ierr = CeedFree(&data); CeedChk(ierr);
1011 
1012   return 0;
1013 }
1014 
1015 //------------------------------------------------------------------------------
1016 // Create tensor basis
1017 //------------------------------------------------------------------------------
// Create a tensor-product H1 basis for the HIP shared-memory backend:
// copies the 1D quadrature weights, interp, and grad matrices to the GPU,
// optionally builds a collocated gradient, picks kernel block sizes, and
// JIT-compiles the shared-memory kernels for this (P1d, Q1d, dim, ncomp).
// Note: qref1d is not used by this backend.
int CeedBasisCreateTensorH1_Hip_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
                                       const CeedScalar *interp1d,
                                       const CeedScalar *grad1d,
                                       const CeedScalar *qref1d,
                                       const CeedScalar *qweight1d,
                                       CeedBasis basis) {
  int ierr;
  Ceed ceed;
  ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
  CeedBasis_Hip_shared *data;
  ierr = CeedCalloc(1, &data); CeedChk(ierr);

  // Copy basis data to GPU
  // qweight1d: Q1d scalars
  const CeedInt qBytes = Q1d * sizeof(CeedScalar);
  ierr = hipMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Hip(ceed, ierr);
  ierr = hipMemcpy(data->d_qweight1d, qweight1d, qBytes,
                   hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);

  // interp1d and grad1d: Q1d x P1d matrices
  const CeedInt iBytes = qBytes * P1d;
  ierr = hipMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Hip(ceed, ierr);
  ierr = hipMemcpy(data->d_interp1d, interp1d, iBytes,
                   hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);

  ierr = hipMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Hip(ceed, ierr);
  ierr = hipMemcpy(data->d_grad1d, grad1d, iBytes,
                   hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);

  // Compute collocated gradient and copy to GPU
  // (Q1d x Q1d matrix, only built for 3D bases with Q1d >= P1d; otherwise
  // d_collograd1d stays NULL)
  data->d_collograd1d = NULL;
  if (dim == 3 && Q1d >= P1d) {
    CeedScalar *collograd1d;
    ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChk(ierr);
    ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChk(ierr);
    ierr = hipMalloc((void **)&data->d_collograd1d, qBytes * Q1d);
    CeedChk_Hip(ceed, ierr);
    ierr = hipMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d,
                     hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
    ierr = CeedFree(&collograd1d); CeedChk(ierr);
  }

  // Set number of threads per block for basis kernels
  // (blksizes[0]=interp, [1]=grad, [2]=weight; baked into the kernels below)
  CeedInt ncomp;
  ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
  ierr = ComputeBasisThreadBlockSizes(dim, P1d, Q1d, ncomp, data->blksizes);
  CeedChk(ierr);

  // Compile basis kernels
  // (JIT the kernelsShared source with the sizes above as macro definitions)
  ierr = CeedCompileHip(ceed, kernelsShared, &data->module, 11,
                        "Q1D", Q1d,
                        "P1D", P1d,
                        "T1D", CeedIntMax(Q1d, P1d),
                        "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
                            Q1d : P1d, dim),
                        "BASIS_DIM", dim,
                        "BASIS_NCOMP", ncomp,
                        "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
                        "BASIS_NQPT", CeedIntPow(Q1d, dim),
                        "INTERP_BLKSIZE", data->blksizes[0],
                        "GRAD_BLKSIZE", data->blksizes[1],
                        "WEIGHT_BLKSIZE", data->blksizes[2]
                       ); CeedChk(ierr);
  // Look up the three entry points in the compiled module
  ierr = CeedGetKernelHip(ceed, data->module, "interp", &data->interp);
  CeedChk(ierr);
  ierr = CeedGetKernelHip(ceed, data->module, "grad", &data->grad);
  CeedChk(ierr);
  ierr = CeedGetKernelHip(ceed, data->module, "weight", &data->weight);
  CeedChk(ierr);

  ierr = CeedBasisSetData(basis, data); CeedChk(ierr);

  // Register backend functions
  ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
                                CeedBasisApplyTensor_Hip_shared);
  CeedChk(ierr);
  ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
                                CeedBasisDestroy_Hip_shared); CeedChk(ierr);
  return 0;
}
1096 //------------------------------------------------------------------------------
1097