xref: /libCEED/backends/hip-shared/ceed-hip-shared-basis.c (revision 602cc54da33f8971b275df2b9b5f2ae2793f4bc5)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include "ceed-hip-shared.h"
18 #include "../hip/ceed-hip-compile.h"
19 
20 //------------------------------------------------------------------------------
21 // Shared mem kernels
22 //------------------------------------------------------------------------------
23 // *INDENT-OFF*
24 static const char *kernelsShared = QUOTE(
25 
26 //------------------------------------------------------------------------------
27 // Sum input into output
28 //------------------------------------------------------------------------------
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  // Accumulate the P1D entries of r_U into r_V, element by element
  for (int j = 0; j < P1D; ++j) {
    r_V[j] = r_V[j] + r_U[j];
  }
}
33 
34 //------------------------------------------------------------------------------
35 // 1D
36 //------------------------------------------------------------------------------
37 
38 //------------------------------------------------------------------------------
39 // Read DoFs
40 //------------------------------------------------------------------------------
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int tidz,const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *slice) {
  // Load this element/component's P1D nodal values into the thread's shared
  // slice row, then zero-pad up to Q1D for the following contraction
  const int offset = elem*P1D + comp*P1D*nelem;
  for (int node = 0; node < P1D; ++node)
    slice[node + tidz*T1D] = d_U[node + offset];
  for (int node = P1D; node < Q1D; ++node)
    slice[node + tidz*T1D] = 0.0;
}
50 
51 //------------------------------------------------------------------------------
52 // Write DoFs
53 //------------------------------------------------------------------------------
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  // Only the first P1D threads in x hold valid nodal values
  const bool active = tidx < P1D;
  if (active) {
    d_V[tidx + elem*P1D + comp*P1D*nelem] = r_V;
  }
}
61 
62 //------------------------------------------------------------------------------
63 // Read quadrature point data
64 //------------------------------------------------------------------------------
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int tidz, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *slice) {
  // Load this element/component/dim's Q1D quadrature values into the thread's
  // shared slice row, then zero-pad up to P1D for the transpose contraction
  const int offset = elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  for (int q = 0; q < Q1D; ++q)
    slice[q + tidz*T1D] = d_U[q + offset];
  for (int q = Q1D; q < P1D; ++q)
    slice[q + tidz*T1D] = 0.0;
}
75 
76 //------------------------------------------------------------------------------
77 // Write quadrature point data
78 //------------------------------------------------------------------------------
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  // Only the first Q1D threads in x hold valid quadrature values
  if (tidx < Q1D) {
    const int ind = tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
    d_V[ind] = r_V;
  }
}
86 
87 //------------------------------------------------------------------------------
88 // 1D tensor contraction
89 //------------------------------------------------------------------------------
// Each thread computes one output: V = sum_i B[i + tidx*P1D] * slice[i],
// i.e. B applied to the values staged in shared memory by readDofs1d.
// NOTE(review): the U argument is unused here — the input comes from slice.
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  V = 0.0;
  for (int i = 0; i < P1D; ++i)
    V += B[i + tidx*P1D] * slice[i + tidz*T1D]; // Contract x direction
}
98 
99 //------------------------------------------------------------------------------
100 // 1D transpose tensor contraction
101 //------------------------------------------------------------------------------
// Transpose of ContractX1d: V = sum_i B[tidx + i*P1D] * slice[i], mapping
// quadrature values back to nodes. NOTE(review): U is unused, as in ContractX1d.
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  V = 0.0;
  for (int i = 0; i < Q1D; ++i)
    V += B[tidx + i*P1D] * slice[i + tidz*T1D]; // Contract x direction
}
109 
110 //------------------------------------------------------------------------------
111 // 1D interpolate to quadrature points
112 //------------------------------------------------------------------------------
// 1D interpolation nodes -> quadrature points via c_B (or the transpose map
// when transpose is set). Each z-slice of the block handles one element and
// blocks grid-stride over elements; components are processed sequentially.
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t; // passed as U to the contractions, which ignore that argument

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;


  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    // Each field component is interpolated independently
    for (int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) {
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      } else {
        readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
141 
142 //------------------------------------------------------------------------------
143 // 1D derivatives at quadrature points
144 //------------------------------------------------------------------------------
// 1D derivative evaluation: applies the gradient matrix c_G (nodes ->
// quadrature points), or its transpose on the way back. One element per
// z-slice of the block; blocks grid-stride over elements.
// NOTE(review): c_B is never referenced in the 1D case — only c_G is applied.
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  CeedScalar r_U; // passed as U to the contractions, which ignore that argument
  CeedScalar r_V;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  int dim;

  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) {
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        dim = 0; // single derivative direction in 1D
        writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      } else {
        dim = 0;
        readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
175 
176 //------------------------------------------------------------------------------
177 // 1D Quadrature weights
178 //------------------------------------------------------------------------------
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  // Each thread in x owns one quadrature point; y-slices of the block
  // grid-stride over elements, replicating the 1D weight per element
  const int qpt = threadIdx.x;
  const CeedScalar wq = qweight1d[qpt];
  const CeedInt first = blockIdx.x*blockDim.y + threadIdx.y;
  const CeedInt stride = gridDim.x*blockDim.y;
  for (CeedInt elem = first; elem < nelem; elem += stride) {
    w[elem*Q1D + qpt] = wq;
  }
}
189 
190 //------------------------------------------------------------------------------
191 // 2D
192 //------------------------------------------------------------------------------
193 
194 //------------------------------------------------------------------------------
195 // Read DoFs
196 //------------------------------------------------------------------------------
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar &U) {
  // Threads outside the P1D x P1D node grid load zero
  const bool active = tidx < P1D && tidy < P1D;
  U = active ? d_U[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] : 0.0;
}
204 
205 //------------------------------------------------------------------------------
206 // Write DoFs
207 //------------------------------------------------------------------------------
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  // Only threads inside the P1D x P1D node grid hold valid data
  if (tidx < P1D && tidy < P1D) {
    const int ind = tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem;
    d_V[ind] = r_V;
  }
}
215 
216 //------------------------------------------------------------------------------
217 // Read quadrature point data
218 //------------------------------------------------------------------------------
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar &U ) {
  // Threads outside the Q1D x Q1D quadrature grid load zero
  const bool active = tidx < Q1D && tidy < Q1D;
  const int ind = tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                  dim*BASIS_NCOMP*nelem*Q1D*Q1D;
  U = active ? d_U[ind] : 0.0;
}
227 
228 //------------------------------------------------------------------------------
229 // Write quadrature point data
230 //------------------------------------------------------------------------------
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  // Only threads inside the Q1D x Q1D quadrature grid hold valid data
  if (tidx < Q1D && tidy < Q1D) {
    const int ind = tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                    dim*BASIS_NCOMP*nelem*Q1D*Q1D;
    d_V[ind] = r_V;
  }
}
239 
240 //------------------------------------------------------------------------------
241 // 2D tensor contraction x
242 //------------------------------------------------------------------------------
// Stage each thread's value in shared memory, then contract along x:
// V(tidx, tidy) = sum_i B[i + tidx*P1D] * U(i, tidy). Must be called by all
// threads of the block (contains __syncthreads()).
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads(); // all of U staged before any cross-thread read
  V = 0.0;
  if (tidx < Q1D)
    for (int i = 0; i < P1D; ++i)
      V += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
  __syncthreads(); // protect slice from the next contraction's writes
}
255 
256 //------------------------------------------------------------------------------
257 // 2D tensor contraction y
258 //------------------------------------------------------------------------------
// Stage each thread's value in shared memory, then contract along y:
// V(tidx, tidy) = sum_i B[i + tidy*P1D] * U(tidx, i). Must be called by all
// threads of the block (contains __syncthreads()).
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads(); // all of U staged before any cross-thread read
  V = 0.0;
  if (tidy < Q1D)
    for (int i = 0; i < P1D; ++i)
      V += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
  __syncthreads(); // protect slice from the next contraction's writes
}
271 
272 //------------------------------------------------------------------------------
273 // 2D transpose tensor contraction y
274 //------------------------------------------------------------------------------
// Transpose contraction along y (quadrature -> nodes):
// V(tidx, tidy) = sum_i B[tidy + i*P1D] * U(tidx, i). Must be called by all
// threads of the block (contains __syncthreads()).
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads(); // all of U staged before any cross-thread read
  V = 0.0;
  if (tidy < P1D)
    for (int i = 0; i < Q1D; ++i)
      V += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
  __syncthreads(); // protect slice from the next contraction's writes
}
286 
287 //------------------------------------------------------------------------------
288 // 2D transpose tensor contraction x
289 //------------------------------------------------------------------------------
// Transpose contraction along x (quadrature -> nodes):
// V(tidx, tidy) = sum_i B[tidx + i*P1D] * U(i, tidy). Must be called by all
// threads of the block (contains __syncthreads()).
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads(); // all of U staged before any cross-thread read
  V = 0.0;
  if (tidx < P1D)
    for (int i = 0; i < Q1D; ++i)
      V += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
  __syncthreads(); // protect slice from the next contraction's writes
}
301 
302 //------------------------------------------------------------------------------
303 // 2D interpolate to quadrature points
304 //------------------------------------------------------------------------------
// 2D tensor-product interpolation: apply c_B along x then y (or the
// transposes in reverse order when transpose is set). Each z-slice of the
// block handles one (element, component) pair; blocks grid-stride over
// elements in steps of elemsPerBlock.
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // Fix: removed a redundant inner `const int comp = tidz%BASIS_NCOMP;`
    // that shadowed the identical value computed above the loop
    r_V = 0.0;
    r_t = 0.0;
    if (!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
    } else {
      readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
338 
339 //------------------------------------------------------------------------------
340 // 2D derivatives at quadrature points
341 //------------------------------------------------------------------------------
// 2D derivatives at quadrature points. Forward: writes d/dx into dim 0 and
// d/dy into dim 1 of d_V (gradient matrix c_G in the differentiated
// direction, interpolation matrix c_B in the other). Transpose: reads both
// dims, applies the transpose contractions, and sums the two contributions
// into the nodal output.
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V, CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if (!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // d/dx: G in x, B in y
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // d/dy: B in x, G in y
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 1;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      // Transpose of the x-derivative contribution
      dim = 0;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      // Transpose of the y-derivative contribution, accumulated into r_V
      dim = 1;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      r_V += r_U;
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
384 
385 //------------------------------------------------------------------------------
386 // 2D quadrature weights
387 //------------------------------------------------------------------------------
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  // Tensor-product weight at quadrature point (qx, qy); z-slices of the
  // block grid-stride over elements
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy];
  const CeedInt first = blockIdx.x*blockDim.z + threadIdx.z;
  const CeedInt stride = gridDim.x*blockDim.z;
  for (CeedInt elem = first; elem < nelem; elem += stride) {
    w[elem*Q1D*Q1D + qx + qy*Q1D] = wq;
  }
}
399 
400 //------------------------------------------------------------------------------
401 // 3D
402 //------------------------------------------------------------------------------
403 
404 //------------------------------------------------------------------------------
405 // Read DoFs
406 //------------------------------------------------------------------------------
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *r_U) {
  // Each active thread loads its z-column of P1D nodal values into registers;
  // inactive threads load zeros. Tail is zero-padded up to Q1D.
  const bool active = tidx < P1D && tidy < P1D;
  const int offset = tidx + tidy*P1D + elem*P1D*P1D*P1D + comp*P1D*P1D*P1D*nelem;
  for (int k = 0; k < P1D; ++k)
    r_U[k] = active ? d_U[offset + k*P1D*P1D] : 0.0;
  for (int k = P1D; k < Q1D; ++k)
    r_U[k] = 0.0;
}
418 
419 //------------------------------------------------------------------------------
420 // Write DoFs
421 //------------------------------------------------------------------------------
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V,
                                   CeedScalar *d_V) {
  // Each active thread stores its z-column of P1D nodal values
  if (tidx < P1D && tidy < P1D) {
    const int offset = tidx + tidy*P1D + elem*P1D*P1D*P1D + comp*P1D*P1D*P1D*nelem;
    for (int k = 0; k < P1D; ++k)
      d_V[offset + k*P1D*P1D] = r_V[k];
  }
}
432 
433 //------------------------------------------------------------------------------
434 // Read quadrature point data
435 //------------------------------------------------------------------------------
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *r_U) {
  // Each active thread loads its z-column of Q1D quadrature values into
  // registers; inactive threads load zeros. Tail is zero-padded up to P1D.
  const bool active = tidx < Q1D && tidy < Q1D;
  const int offset = tidx + tidy*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
                     dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
  for (int k = 0; k < Q1D; ++k)
    r_U[k] = active ? d_U[offset + k*Q1D*Q1D] : 0.0;
  for (int k = Q1D; k < P1D; ++k)
    r_U[k] = 0.0;
}
447 
448 //------------------------------------------------------------------------------
449 // Write quadrature point data
450 //------------------------------------------------------------------------------
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar *r_V, CeedScalar *d_V) {
  // Each active thread stores its z-column of Q1D quadrature values
  if (tidx < Q1D && tidy < Q1D) {
    const int offset = tidx + tidy*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
                       dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
    for (int k = 0; k < Q1D; ++k)
      d_V[offset + k*Q1D*Q1D] = r_V[k];
  }
}
461 
462 //------------------------------------------------------------------------------
463 // 3D tensor contract x
464 //------------------------------------------------------------------------------
// Contract along x, one xy-plane per z-level k, staging each plane of U in
// shared memory: V[k](tidx, tidy) = sum_i B[i + tidx*P1D] * U[k](i, tidy).
// Must be called by all threads of the block (contains __syncthreads()).
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads(); // plane fully staged before cross-thread reads
    V[k] = 0.0;
    if (tidx < Q1D && tidy < P1D)
      for (int i = 0; i < P1D; ++i)
        V[k] += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
    __syncthreads(); // slice is reused for the next plane
  }
}
480 
481 //------------------------------------------------------------------------------
482 // 3D tensor contract y
483 //------------------------------------------------------------------------------
// Contract along y, one xy-plane per z-level k, staging each plane of U in
// shared memory: V[k](tidx, tidy) = sum_i B[i + tidy*P1D] * U[k](tidx, i).
// Must be called by all threads of the block (contains __syncthreads()).
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads(); // plane fully staged before cross-thread reads
    V[k] = 0.0;
    if (tidx < Q1D && tidy < Q1D)
      for (int i = 0; i < P1D; ++i)
        V[k] += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
    __syncthreads(); // slice is reused for the next plane
  }
}
499 
500 //------------------------------------------------------------------------------
501 // 3D tensor contract z
502 //------------------------------------------------------------------------------
// Contract along z entirely in registers (each thread holds its own
// z-column in U): V[k] = sum_i B[i + k*P1D] * U[i]. No shared memory or
// synchronization needed. Tail is zero-padded up to P1D.
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < Q1D; ++k) {
    V[k] = 0.0;
    if (tidx < Q1D && tidy < Q1D)
      for (int i = 0; i < P1D; ++i)
        V[k] += B[i + k*P1D] * U[i]; // Contract z direction
  }
  for (int k = Q1D; k < P1D; ++k)
    V[k] = 0.0;
}
517 
518 //------------------------------------------------------------------------------
519 // 3D transpose tensor contract z
520 //------------------------------------------------------------------------------
// Transpose contraction along z in registers (quadrature -> nodes):
// V[k] = sum_i B[k + i*P1D] * U[i]. No shared memory or synchronization
// needed. Tail is zero-padded up to Q1D.
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    V[k] = 0.0;
    if (tidx < Q1D && tidy < Q1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[k + i*P1D] * U[i]; // Contract z direction
  }
  for (int k = P1D; k < Q1D; ++k)
    V[k] = 0.0;
}
535 
536 //------------------------------------------------------------------------------
537 // 3D transpose tensor contract y
538 //------------------------------------------------------------------------------
// Transpose contraction along y (quadrature -> nodes), one xy-plane per
// z-level k: V[k](tidx, tidy) = sum_i B[tidy + i*P1D] * U[k](tidx, i).
// Must be called by all threads of the block (contains __syncthreads()).
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads(); // plane fully staged before cross-thread reads
    V[k] = 0.0;
    if (tidx < Q1D && tidy < P1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
    __syncthreads(); // slice is reused for the next plane
  }
}
554 
555 //------------------------------------------------------------------------------
556 // 3D transpose tensor contract x
557 //------------------------------------------------------------------------------
// Transpose contraction along x (quadrature -> nodes), one xy-plane per
// z-level k: V[k](tidx, tidy) = sum_i B[tidx + i*P1D] * U[k](i, tidy).
// Must be called by all threads of the block (contains __syncthreads()).
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads(); // plane fully staged before cross-thread reads
    V[k] = 0.0;
    if (tidx < P1D && tidy < P1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
    __syncthreads(); // slice is reused for the next plane
  }
}
573 
574 //------------------------------------------------------------------------------
575 // 3D interpolate to quadrature points
576 //------------------------------------------------------------------------------
// 3D tensor-product interpolation: apply c_B along x, y, then z (or the
// transposes in reverse order when transpose is set). Each z-slice of the
// block handles one (element, component) pair; blocks grid-stride over
// elements in steps of elemsPerBlock. r_V and r_t are per-thread z-column
// buffers that ping-pong between contraction stages.
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V[T1D];
  CeedScalar r_t[T1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // Clear register buffers for this element
    for (int i = 0; i < T1D; ++i) {
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
    } else {
      readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
    }
  }
}
613 
614 //------------------------------------------------------------------------------
615 // 3D derivatives at quadrature points
616 //------------------------------------------------------------------------------
// 3D derivatives at quadrature points. Forward: writes d/dx, d/dy, d/dz into
// dims 0..2 of d_V (gradient matrix c_G in the differentiated direction,
// interpolation matrix c_B in the other two). Transpose: reads all three
// dims, applies the transpose contractions, and accumulates the three
// contributions into the nodal output via add().
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // Use P1D for one of these
  CeedScalar r_U[T1D];
  CeedScalar r_V[T1D];
  CeedScalar r_t[T1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // Clear register buffers for this element
    for (int i = 0; i < T1D; ++i) {
      r_U[i] = 0.0;
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // d/dx: G in x, B in y and z
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // d/dy: G in y, B in x and z
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_G, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 1;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // d/dz: G in z, B in x and y
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 2;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      // Transpose of the x-derivative contribution (result starts in r_V)
      dim = 0;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      // Transpose of the y-derivative contribution, accumulated into r_V
      dim = 1;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_G, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      // Transpose of the z-derivative contribution, accumulated into r_V
      dim = 2;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
681 
682 //------------------------------------------------------------------------------
683 // 3D quadrature weights
684 //------------------------------------------------------------------------------
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  // Each thread owns one 3D quadrature point; its tensor-product weight is
  // the same for every element, so compute it once before the element loop
  const CeedScalar qw = qweight1d[tidx]*qweight1d[tidy]*qweight1d[tidz];
  // Grid-stride over elements: one block handles elements blockIdx.x,
  // blockIdx.x + gridDim.x, ...
  for (CeedInt elem = blockIdx.x; elem < nelem; elem += gridDim.x)
    w[elem*Q1D*Q1D*Q1D + tidz*Q1D*Q1D + tidy*Q1D + tidx] = qw;
}
696 
697 
698 //------------------------------------------------------------------------------
699 // Basis kernels
700 //------------------------------------------------------------------------------
701 
702 //------------------------------------------------------------------------------
703 // Interp kernel by dim
704 //------------------------------------------------------------------------------
extern "C" __global__ void interp(const CeedInt nelem, const int transpose,
                                  const CeedScalar *c_B,
                                  const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  HIP_DYNAMIC_SHARED( double, slice)
  // BASIS_DIM is a JIT-time constant, so only one case survives compilation
  switch (BASIS_DIM) {
  case 1:
    interp1d(nelem, transpose, c_B, d_U, d_V, slice);
    break;
  case 2:
    interp2d(nelem, transpose, c_B, d_U, d_V, slice);
    break;
  case 3:
    interp3d(nelem, transpose, c_B, d_U, d_V, slice);
    break;
  }
}
718 
719 //------------------------------------------------------------------------------
720 // Grad kernel by dim
721 //------------------------------------------------------------------------------
extern "C" __global__ void grad(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B, const CeedScalar *c_G,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V) {
  HIP_DYNAMIC_SHARED( double, slice)
  // BASIS_DIM is a JIT-time constant, so only one case survives compilation
  switch (BASIS_DIM) {
  case 1:
    grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
    break;
  case 2:
    grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
    break;
  case 3:
    grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
    break;
  }
}
735 
736 //------------------------------------------------------------------------------
737 // Weight kernels by dim
738 //------------------------------------------------------------------------------
extern "C" __global__ void weight(const CeedInt nelem,
                                  const CeedScalar *__restrict__ qweight1d,
                                  CeedScalar *__restrict__ v) {
  // BASIS_DIM is a JIT-time constant, so only one case survives compilation
  switch (BASIS_DIM) {
  case 1:
    weight1d(nelem, qweight1d, v);
    break;
  case 2:
    weight2d(nelem, qweight1d, v);
    break;
  case 3:
    weight3d(nelem, qweight1d, v);
    break;
  }
}
750 
751 );
752 // *INDENT-ON*
753 
754 //------------------------------------------------------------------------------
755 // Device initalization
756 //------------------------------------------------------------------------------
757 int CeedHipInitInterp(CeedScalar *d_B, CeedInt P1d, CeedInt Q1d,
758                       CeedScalar **c_B);
759 int CeedHipInitInterpGrad(CeedScalar *d_B, CeedScalar *d_G, CeedInt P1d,
760                           CeedInt Q1d, CeedScalar **c_B_ptr,
761                           CeedScalar **c_G_ptr);
762 
763 //------------------------------------------------------------------------------
764 // Apply basis
765 //------------------------------------------------------------------------------
766 int CeedBasisApplyTensor_Hip_shared(CeedBasis basis, const CeedInt nelem,
767                                     CeedTransposeMode tmode,
768                                     CeedEvalMode emode, CeedVector u,
769                                     CeedVector v) {
770   int ierr;
771   Ceed ceed;
772   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
773   Ceed_Hip_shared *ceed_Hip;
774   CeedGetData(ceed, &ceed_Hip); CeedChk(ierr);
775   CeedBasis_Hip_shared *data;
776   CeedBasisGetData(basis, &data); CeedChk(ierr);
777   const CeedInt transpose = tmode == CEED_TRANSPOSE;
778   CeedInt dim, ncomp;
779   ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
780   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
781 
782   // Read vectors
783   const CeedScalar *d_u;
784   CeedScalar *d_v;
785   if (emode != CEED_EVAL_WEIGHT) {
786     ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChk(ierr);
787   }
788   ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChk(ierr);
789 
790   // Clear v for transpose mode
791   if (tmode == CEED_TRANSPOSE) {
792     CeedInt length;
793     ierr = CeedVectorGetLength(v, &length); CeedChk(ierr);
794     ierr = hipMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChk(ierr);
795   }
796 
797   // Apply basis operation
798   switch (emode) {
799   case CEED_EVAL_INTERP: {
800     CeedInt P1d, Q1d;
801     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
802     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
803     CeedInt thread1d = CeedIntMax(Q1d, P1d);
804     ierr = CeedHipInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B);
805     CeedChk(ierr);
806     void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
807                           &d_u, &d_v
808                          };
809     if (dim == 1) {
810       CeedInt elemsPerBlock = 64*thread1d > 256? 256/thread1d : 64;
811       elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1;
812       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
813                                              ? 1 : 0 );
814       CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
815       ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, 1,
816                                        elemsPerBlock, sharedMem,
817                                        interpargs); CeedChk(ierr);
818     } else if (dim == 2) {
819       const CeedInt optElems[7] = {0,32,8,6,4,2,6};
820       // elemsPerBlock must be at least 1
821       CeedInt elemsPerBlock = CeedIntMax(thread1d<7?optElems[thread1d]/ncomp:1, 1);
822       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
823                                              ? 1 : 0 );
824       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
825       ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d,
826                                        ncomp*elemsPerBlock, sharedMem,
827                                        interpargs); CeedChk(ierr);
828     } else if (dim == 3) {
829       CeedInt elemsPerBlock = 1;
830       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
831                                              ? 1 : 0 );
832       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
833       ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d,
834                                        ncomp*elemsPerBlock, sharedMem,
835                                        interpargs); CeedChk(ierr);
836     }
837   } break;
838   case CEED_EVAL_GRAD: {
839     CeedInt P1d, Q1d;
840     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
841     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
842     CeedInt thread1d = CeedIntMax(Q1d, P1d);
843     ierr = CeedHipInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d,
844                                  Q1d, &data->c_B, &data->c_G);
845     CeedChk(ierr);
846     void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
847                         &data->c_G, &d_u, &d_v
848                        };
849     if (dim == 1) {
850       CeedInt elemsPerBlock = 64*thread1d > 256? 256/thread1d : 64;
851       elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1;
852       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
853                                              ? 1 : 0 );
854       CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
855       ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, 1,
856                                        elemsPerBlock, sharedMem, gradargs);
857       CeedChk(ierr);
858     } else if (dim == 2) {
859       const CeedInt optElems[7] = {0,32,8,6,4,2,6};
860       // elemsPerBlock must be at least 1
861       CeedInt elemsPerBlock = CeedIntMax(thread1d<7?optElems[thread1d]/ncomp:1, 1);
862       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
863                                              ? 1 : 0 );
864       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
865       ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d,
866                                        ncomp*elemsPerBlock, sharedMem,
867                                        gradargs); CeedChk(ierr);
868     } else if (dim == 3) {
869       CeedInt elemsPerBlock = 1;
870       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
871                                              ? 1 : 0 );
872       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
873       ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d,
874                                        ncomp*elemsPerBlock, sharedMem,
875                                        gradargs); CeedChk(ierr);
876     }
877   } break;
878   case CEED_EVAL_WEIGHT: {
879     CeedInt Q1d;
880     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
881     void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
882     if (dim == 1) {
883       const CeedInt optElems = 64/Q1d;
884       const CeedInt elemsPerBlock = optElems>0?optElems:1;
885       const CeedInt gridsize = nelem/elemsPerBlock + ( (
886                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
887       ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d,
888                                  elemsPerBlock, 1, weightargs);
889       CeedChk(ierr);
890     } else if (dim == 2) {
891       const CeedInt optElems = 64/(Q1d*Q1d);
892       const CeedInt elemsPerBlock = optElems>0?optElems:1;
893       const CeedInt gridsize = nelem/elemsPerBlock + ( (
894                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
895       ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d,
896                                  elemsPerBlock, weightargs);
897       CeedChk(ierr);
898     } else if (dim == 3) {
899       const CeedInt gridsize = nelem;
900       ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d, Q1d,
901                                  weightargs);
902       CeedChk(ierr);
903     }
904   } break;
905   // LCOV_EXCL_START
906   // Evaluate the divergence to/from the quadrature points
907   case CEED_EVAL_DIV:
908     return CeedError(ceed, 1, "CEED_EVAL_DIV not supported");
909   // Evaluate the curl to/from the quadrature points
910   case CEED_EVAL_CURL:
911     return CeedError(ceed, 1, "CEED_EVAL_CURL not supported");
912   // Take no action, BasisApply should not have been called
913   case CEED_EVAL_NONE:
914     return CeedError(ceed, 1,
915                      "CEED_EVAL_NONE does not make sense in this context");
916     // LCOV_EXCL_STOP
917   }
918 
919   // Restore vectors
920   if (emode != CEED_EVAL_WEIGHT) {
921     ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChk(ierr);
922   }
923   ierr = CeedVectorRestoreArray(v, &d_v); CeedChk(ierr);
924   return 0;
925 }
926 
927 //------------------------------------------------------------------------------
928 // Destroy basis
929 //------------------------------------------------------------------------------
930 static int CeedBasisDestroy_Hip_shared(CeedBasis basis) {
931   int ierr;
932   Ceed ceed;
933   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
934 
935   CeedBasis_Hip_shared *data;
936   ierr = CeedBasisGetData(basis, &data); CeedChk(ierr);
937 
938   CeedChk_Hip(ceed, hipModuleUnload(data->module));
939 
940   ierr = hipFree(data->d_qweight1d); CeedChk_Hip(ceed, ierr);
941   ierr = hipFree(data->d_interp1d); CeedChk_Hip(ceed, ierr);
942   ierr = hipFree(data->d_grad1d); CeedChk_Hip(ceed, ierr);
943   ierr = hipFree(data->d_collograd1d); CeedChk_Hip(ceed, ierr);
944 
945   ierr = CeedFree(&data); CeedChk(ierr);
946 
947   return 0;
948 }
949 
950 //------------------------------------------------------------------------------
951 // Create tensor basis
952 //------------------------------------------------------------------------------
953 int CeedBasisCreateTensorH1_Hip_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
954                                        const CeedScalar *interp1d,
955                                        const CeedScalar *grad1d,
956                                        const CeedScalar *qref1d,
957                                        const CeedScalar *qweight1d,
958                                        CeedBasis basis) {
959   int ierr;
960   Ceed ceed;
961   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
962   CeedBasis_Hip_shared *data;
963   ierr = CeedCalloc(1, &data); CeedChk(ierr);
964 
965   // Copy basis data to GPU
966   const CeedInt qBytes = Q1d * sizeof(CeedScalar);
967   ierr = hipMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Hip(ceed, ierr);
968   ierr = hipMemcpy(data->d_qweight1d, qweight1d, qBytes,
969                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
970 
971   const CeedInt iBytes = qBytes * P1d;
972   ierr = hipMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Hip(ceed, ierr);
973   ierr = hipMemcpy(data->d_interp1d, interp1d, iBytes,
974                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
975 
976   ierr = hipMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Hip(ceed, ierr);
977   ierr = hipMemcpy(data->d_grad1d, grad1d, iBytes,
978                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
979 
980   // Compute collocated gradient and copy to GPU
981   data->d_collograd1d = NULL;
982   if (dim == 3 && Q1d >= P1d) {
983     CeedScalar *collograd1d;
984     ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChk(ierr);
985     ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChk(ierr);
986     ierr = hipMalloc((void **)&data->d_collograd1d, qBytes * Q1d);
987     CeedChk_Hip(ceed, ierr);
988     ierr = hipMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d,
989                      hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
990     ierr = CeedFree(&collograd1d); CeedChk(ierr);
991   }
992 
993   // Compile basis kernels
994   CeedInt ncomp;
995   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
996   ierr = CeedCompileHip(ceed, kernelsShared, &data->module, 8,
997                         "Q1D", Q1d,
998                         "P1D", P1d,
999                         "T1D", CeedIntMax(Q1d, P1d),
1000                         "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
1001                             Q1d : P1d, dim),
1002                         "BASIS_DIM", dim,
1003                         "BASIS_NCOMP", ncomp,
1004                         "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
1005                         "BASIS_NQPT", CeedIntPow(Q1d, dim)
1006                        ); CeedChk(ierr);
1007   ierr = CeedGetKernelHip(ceed, data->module, "interp", &data->interp);
1008   CeedChk(ierr);
1009   ierr = CeedGetKernelHip(ceed, data->module, "grad", &data->grad);
1010   CeedChk(ierr);
1011   ierr = CeedGetKernelHip(ceed, data->module, "weight", &data->weight);
1012   CeedChk(ierr);
1013 
1014   ierr = CeedBasisSetData(basis, data); CeedChk(ierr);
1015 
1016   // Register backend functions
1017   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
1018                                 CeedBasisApplyTensor_Hip_shared);
1019   CeedChk(ierr);
1020   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
1021                                 CeedBasisDestroy_Hip_shared); CeedChk(ierr);
1022   return 0;
1023 }
1024 //------------------------------------------------------------------------------
1025