xref: /libCEED/backends/hip-shared/ceed-hip-shared-basis.c (revision b425b72ce6207063c385cb41f0ea18e8c839ca9f)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <ceed/ceed.h>
18 #include <ceed/backend.h>
19 #include <hip/hip_runtime.h>
20 #include <stddef.h>
21 #include "ceed-hip-shared.h"
22 #include "../hip/ceed-hip.h"
23 #include "../hip/ceed-hip-compile.h"
24 
25 //------------------------------------------------------------------------------
26 // Shared mem kernels
27 //------------------------------------------------------------------------------
28 // *INDENT-OFF*
29 static const char *kernelsShared = QUOTE(
30 
31 //------------------------------------------------------------------------------
32 // Sum input into output
33 //------------------------------------------------------------------------------
// Accumulate the first P1D entries of r_U into r_V, element-wise.
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  for (int j = 0; j < P1D; j++) {
    r_V[j] = r_V[j] + r_U[j];
  }
}
38 
39 //------------------------------------------------------------------------------
40 // Load matrices for basis actions
41 //------------------------------------------------------------------------------
// Cooperatively copy the P1D*Q1D basis matrix d_B into B: every thread in the
// block takes a strided share of the entries, starting at its flat thread id.
inline __device__ void loadMatrix(const CeedScalar* d_B, CeedScalar* B) {
  const CeedInt nthreads = blockDim.x*blockDim.y*blockDim.z;
  const CeedInt first = threadIdx.x + threadIdx.y*blockDim.x +
                        threadIdx.z*blockDim.y*blockDim.x;
  for (CeedInt k = first; k < P1D*Q1D; k += nthreads)
    B[k] = d_B[k];
}
47 
48 //------------------------------------------------------------------------------
49 // 1D
50 //------------------------------------------------------------------------------
51 
52 //------------------------------------------------------------------------------
53 // Read DoFs
54 //------------------------------------------------------------------------------
// Load all P1D nodal values of component 'comp' for element 'elem' into this
// thread's row of the scratch buffer 'slice' (row selected by tidz*T1D), then
// zero-pad entries P1D..Q1D-1 so contractions may safely loop up to Q1D.
// d_U layout: node + elem*P1D + comp*P1D*nelem (node index fastest).
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int tidz,const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *slice) {
  for (int i = 0; i < P1D; i++)
    slice[i + tidz*T1D] = d_U[i + elem*P1D + comp*P1D*nelem];
  for (int i = P1D; i < Q1D; i++)
    slice[i + tidz*T1D] = 0.0; // padding when Q1D > P1D
}
64 
65 //------------------------------------------------------------------------------
66 // Write DoFs
67 //------------------------------------------------------------------------------
// Store one nodal value r_V for (elem, comp); only the first P1D threads in x
// own valid node indices, the rest write nothing.
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D)
    return;
  const int ind = tidx + elem*P1D + comp*P1D*nelem;
  d_V[ind] = r_V;
}
75 
76 //------------------------------------------------------------------------------
77 // Read quadrature point data
78 //------------------------------------------------------------------------------
// Load all Q1D quadrature-point values for (elem, comp, dim) into this
// thread's row of 'slice', zero-padding entries Q1D..P1D-1 (when P1D > Q1D).
// d_U layout: qpt + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D.
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int tidz, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *slice) {
  for (int i = 0; i < Q1D; i++)
    slice[i + tidz*T1D] = d_U[i + elem*Q1D + comp*Q1D*nelem +
                            dim*BASIS_NCOMP*nelem*Q1D];
  for (int i = Q1D; i < P1D; i++)
    slice[i + tidz*T1D] = 0.0; // padding when P1D > Q1D
}
89 
90 //------------------------------------------------------------------------------
91 // Write quadrature point data
92 //------------------------------------------------------------------------------
// Store one quadrature-point value r_V for (elem, comp, dim); only the first
// Q1D threads in x own valid quadrature indices.
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx >= Q1D)
    return;
  const int ind = tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  d_V[ind] = r_V;
}
100 
101 //------------------------------------------------------------------------------
102 // 1D tensor contraction
103 //------------------------------------------------------------------------------
// V = sum_i B[i + tidx*P1D] * slice[i]: contract the node index with the
// basis matrix; this thread produces the value at quadrature point tidx.
// NOTE: the U parameter is unused — the input comes from 'slice' (filled by a
// prior readDofs1d); U is kept only for interface symmetry with the 2D/3D
// contractions.
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  V = 0.0;
  for (int i = 0; i < P1D; ++i)
    V += B[i + tidx*P1D] * slice[i + tidz*T1D]; // Contract x direction
}
112 
113 //------------------------------------------------------------------------------
114 // 1D transpose tensor contraction
115 //------------------------------------------------------------------------------
// Transpose contraction: V = sum_i B[tidx + i*P1D] * slice[i]; this thread
// produces the value at node tidx from Q1D quadrature values in 'slice'.
// The U parameter is unused (input is read from 'slice'); kept for interface
// symmetry with the 2D/3D contractions.
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  V = 0.0;
  for (int i = 0; i < Q1D; ++i)
    V += B[tidx + i*P1D] * slice[i + tidz*T1D]; // Contract x direction
}
123 
124 //------------------------------------------------------------------------------
125 // 1D interpolate to quadrature points
126 //------------------------------------------------------------------------------
// Apply the 1D interpolation matrix s_B (or its transpose) for all elements.
// Each z-slice of the thread block handles one element; the outer loop strides
// the whole grid over elements. r_t is passed to the contractions only to
// satisfy their interface — the 1D contractions read input from 'slice'.
// 'slice' is per-block scratch (shared memory, per the file's kernel design —
// see the "Shared mem kernels" header).
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *s_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;

  // Grid-stride loop over elements: each z-slice owns one element at a time.
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for (int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) {
        // Nodes -> quadrature points.
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
        writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      } else {
        // Quadrature points -> nodes.
        readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
154 
155 //------------------------------------------------------------------------------
156 // 1D derivatives at quadrature points
157 //------------------------------------------------------------------------------
// Apply the 1D gradient matrix s_G (or its transpose) for all elements.
// In 1D there is a single derivative direction, so dim is always 0.
// s_B is accepted for interface parity with grad2d/grad3d but is not used
// here. r_U is passed to the contractions only to satisfy their interface —
// the 1D contractions read input from 'slice'.
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *s_B, const CeedScalar *s_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  int dim;

  // Grid-stride loop over elements: each z-slice owns one element at a time.
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) {
        // Nodes -> derivative at quadrature points.
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_U, s_G, r_V);
        dim = 0;
        writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      } else {
        // Transpose action: quadrature derivative data -> nodes.
        dim = 0;
        readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, s_G, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
188 
189 //------------------------------------------------------------------------------
190 // 1D Quadrature weights
191 //------------------------------------------------------------------------------
// Broadcast the 1D quadrature weight for point threadIdx.x into every
// element's weight slot; blocks stride over elements along y.
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qpt = threadIdx.x;
  const CeedScalar wq = qweight1d[qpt];
  const CeedInt first = blockIdx.x*blockDim.y + threadIdx.y;
  const CeedInt stride = gridDim.x*blockDim.y;
  for (CeedInt e = first; e < nelem; e += stride)
    w[e*Q1D + qpt] = wq;
}
202 
203 //------------------------------------------------------------------------------
204 // 2D
205 //------------------------------------------------------------------------------
206 
207 //------------------------------------------------------------------------------
208 // Read DoFs
209 //------------------------------------------------------------------------------
// Load this thread's nodal value for (elem, comp) into U; threads outside the
// P1D x P1D node grid receive zero padding.
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar &U) {
  if (tidx < P1D && tidy < P1D) {
    const int ind = tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem;
    U = d_U[ind];
  } else {
    U = 0.0;
  }
}
217 
218 //------------------------------------------------------------------------------
219 // Write DoFs
220 //------------------------------------------------------------------------------
// Store this thread's nodal value r_V for (elem, comp); threads outside the
// P1D x P1D node grid write nothing.
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D)
    return;
  const int ind = tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem;
  d_V[ind] = r_V;
}
228 
229 //------------------------------------------------------------------------------
230 // Read quadrature point data
231 //------------------------------------------------------------------------------
// Load this thread's quadrature-point value for (elem, comp, dim) into U;
// threads outside the Q1D x Q1D quadrature grid receive zero padding.
// d_U layout: x + y*Q1D + elem*Q1D^2 + comp*Q1D^2*nelem
//             + dim*BASIS_NCOMP*nelem*Q1D^2.
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar &U ) {
  U = (tidx<Q1D && tidy<Q1D) ?
      d_U[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
      dim*BASIS_NCOMP*nelem*Q1D*Q1D] : 0.0;
}
240 
241 //------------------------------------------------------------------------------
242 // Write quadrature point data
243 //------------------------------------------------------------------------------
// Store this thread's quadrature-point value r_V for (elem, comp, dim);
// threads outside the Q1D x Q1D quadrature grid write nothing.
// Index layout mirrors readQuads2d.
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  if (tidx<Q1D && tidy<Q1D)
    d_V[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
    dim*BASIS_NCOMP*nelem*Q1D*Q1D] = r_V;
}
252 
253 //------------------------------------------------------------------------------
254 // 2D tensor contraction x
255 //------------------------------------------------------------------------------
// Contract the x index of U with B (stored B[p + q*P1D]): stage U into the
// block scratch 'slice', then each thread with tidx < Q1D accumulates its row.
// Both __syncthreads() barriers sit outside the guard so EVERY thread in the
// block reaches them (non-divergent); the trailing barrier protects 'slice'
// from being overwritten by the next contraction before all reads finish.
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidx < Q1D)
    for (int i = 0; i < P1D; ++i)
      V += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
  __syncthreads();
}
268 
269 //------------------------------------------------------------------------------
270 // 2D tensor contraction y
271 //------------------------------------------------------------------------------
// Contract the y index of U with B (stored B[p + q*P1D]); same staging and
// barrier discipline as ContractX2d — barriers are outside the guard so all
// threads reach them.
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidy < Q1D)
    for (int i = 0; i < P1D; ++i)
      V += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
  __syncthreads();
}
284 
285 //------------------------------------------------------------------------------
286 // 2D transpose tensor contraction y
287 //------------------------------------------------------------------------------
// Transpose contraction of the y index: output at node row tidy (< P1D) from
// Q1D quadrature values; B is accessed transposed as B[tidy + i*P1D].
// Barriers are outside the guard so all threads reach them.
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidy < P1D)
    for (int i = 0; i < Q1D; ++i)
      V += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
  __syncthreads();
}
299 
300 //------------------------------------------------------------------------------
301 // 2D transpose tensor contraction x
302 //------------------------------------------------------------------------------
// Transpose contraction of the x index: output at node column tidx (< P1D)
// from Q1D quadrature values; B is accessed transposed as B[tidx + i*P1D].
// Barriers are outside the guard so all threads reach them.
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidx < P1D)
    for (int i = 0; i < Q1D; ++i)
      V += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
  __syncthreads();
}
314 
315 //------------------------------------------------------------------------------
316 // 2D interpolate to quadrature points
317 //------------------------------------------------------------------------------
// Apply the 1D interpolation matrix s_B in both tensor directions (or the
// transpose action when 'transpose' is set). Each z-slice of the thread block
// handles one (element, component) pair: tidz/BASIS_NCOMP picks the element
// within the block, tidz%BASIS_NCOMP picks the component. Fix: the loop body
// previously redeclared 'comp' each iteration, shadowing the identical outer
// declaration — the redundant inner declaration is removed (same value).
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *s_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP; // loop-invariant component index

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    r_V = 0.0;
    r_t = 0.0;
    if (!transpose) {
      // Nodes -> quadrature points: contract x then y with s_B.
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX2d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
    } else {
      // Quadrature points -> nodes: transpose-contract y then x.
      readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
351 
352 //------------------------------------------------------------------------------
353 // 2D derivatives at quadrature points
354 //------------------------------------------------------------------------------
// 2D gradient: forward pass writes d/dx (G in x, B in y) to dim 0 and d/dy
// (B in x, G in y) to dim 1. Transpose pass reads both dim slots, applies the
// transposed contractions, and sums the two contributions into r_V before
// storing at the nodes. Each z-slice handles one (element, component) pair.
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *s_B, const CeedScalar *s_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V, CeedScalar *slice) {
  CeedScalar r_U;
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;      // element owned within this block
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;           // component owned by this z-slice
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if (!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // d/dx: gradient matrix in x, interpolation in y.
      ContractX2d(slice, tidx, tidy, tidz, r_U, s_G, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
      dim = 0;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // d/dy: interpolation in x, gradient matrix in y (r_U still holds DoFs).
      ContractX2d(slice, tidx, tidy, tidz, r_U, s_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, s_G, r_V);
      dim = 1;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      // Transpose action of d/dx into r_V.
      dim = 0;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, s_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, s_G, r_V);
      // Transpose action of d/dy, accumulated onto r_V.
      dim = 1;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, s_G, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, s_B, r_U);
      r_V += r_U;
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
397 
398 //------------------------------------------------------------------------------
399 // 2D quadrature weights
400 //------------------------------------------------------------------------------
// Tensor-product 2D quadrature weight for point (threadIdx.x, threadIdx.y),
// replicated across elements; blocks stride over elements along z.
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  const int qx = threadIdx.x;
  const int qy = threadIdx.y;
  const CeedScalar wq = qweight1d[qx]*qweight1d[qy];
  const CeedInt first = blockIdx.x*blockDim.z + threadIdx.z;
  const CeedInt stride = gridDim.x*blockDim.z;
  for (CeedInt e = first; e < nelem; e += stride)
    w[e*Q1D*Q1D + qx + qy*Q1D] = wq;
}
412 
413 //------------------------------------------------------------------------------
414 // 3D
415 //------------------------------------------------------------------------------
416 
417 //------------------------------------------------------------------------------
418 // Read DoFs
419 //------------------------------------------------------------------------------
// Load a z-column of P1D nodal values for (elem, comp) into registers r_U;
// threads outside the P1D x P1D footprint load zeros, and entries P1D..Q1D-1
// are zero-padded so later contractions may index up to Q1D.
// d_U layout: x + y*P1D + z*P1D^2 + elem*P1D^3 + comp*P1D^3*nelem.
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *r_U) {
  for (int i = 0; i < P1D; i++)
    r_U[i] = (tidx < P1D && tidy < P1D) ?
              d_U[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D +
                  comp*P1D*P1D*P1D*nelem] : 0.0;
  for (int i = P1D; i < Q1D; i++)
    r_U[i] = 0.0; // padding when Q1D > P1D
}
431 
432 //------------------------------------------------------------------------------
433 // Write DoFs
434 //------------------------------------------------------------------------------
// Store a z-column of P1D nodal values from registers r_V for (elem, comp);
// threads outside the P1D x P1D footprint write nothing.
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V,
                                   CeedScalar *d_V) {
  if (tidx >= P1D || tidy >= P1D)
    return;
  const int base = tidx + tidy*P1D + elem*P1D*P1D*P1D + comp*P1D*P1D*P1D*nelem;
  for (int k = 0; k < P1D; k++)
    d_V[base + k*P1D*P1D] = r_V[k];
}
445 
446 //------------------------------------------------------------------------------
447 // Read quadrature point data
448 //------------------------------------------------------------------------------
// Load a z-column of Q1D quadrature values for (elem, comp, dim) into
// registers r_U; threads outside the Q1D x Q1D footprint load zeros, and
// entries Q1D..P1D-1 are zero-padded (when P1D > Q1D).
// d_U layout: x + y*Q1D + z*Q1D^2 + elem*Q1D^3 + comp*Q1D^3*nelem
//             + dim*BASIS_NCOMP*nelem*Q1D^3.
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *r_U) {
  for (int i = 0; i < Q1D; i++)
    r_U[i] = (tidx < Q1D && tidy < Q1D) ?
              d_U[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D +
              comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] : 0.0;
  for (int i = Q1D; i < P1D; i++)
    r_U[i] = 0.0; // padding when P1D > Q1D
}
460 
461 //------------------------------------------------------------------------------
462 // Write quadrature point data
463 //------------------------------------------------------------------------------
// Store a z-column of Q1D quadrature values from registers r_V for
// (elem, comp, dim); threads outside the Q1D x Q1D footprint write nothing.
// Index layout mirrors readQuads3d.
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar *r_V, CeedScalar *d_V) {
  if (tidx < Q1D && tidy < Q1D) {
    for (int i = 0; i < Q1D; i++)
      d_V[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
          dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] = r_V[i];
  }
}
474 
475 //------------------------------------------------------------------------------
476 // 3D tensor contract x
477 //------------------------------------------------------------------------------
// Contract the x index plane-by-plane: for each z-plane k, stage U[k] into the
// block scratch 'slice' and accumulate V[k] over the P1D nodes. Both barriers
// are outside the guard so every thread reaches them in every iteration; the
// trailing barrier keeps the next plane's write from racing this plane's reads.
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < Q1D && tidy < P1D)
      for (int i = 0; i < P1D; ++i)
        V[k] += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
    __syncthreads();
  }
}
493 
494 //------------------------------------------------------------------------------
495 // 3D tensor contract y
496 //------------------------------------------------------------------------------
// Contract the y index plane-by-plane via the block scratch 'slice'; same
// staging and barrier discipline as ContractX3d (barriers outside the guard).
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < Q1D && tidy < Q1D)
      for (int i = 0; i < P1D; ++i)
        V[k] += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
    __syncthreads();
  }
}
512 
513 //------------------------------------------------------------------------------
514 // 3D tensor contract z
515 //------------------------------------------------------------------------------
// Contract the z index entirely in registers: each thread already holds its
// full z-column in U, so no shared-memory staging or __syncthreads() is
// needed ('slice' is unused, kept for interface symmetry). Entries Q1D..P1D-1
// are zeroed (relevant only when P1D > Q1D).
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  for (int k = 0; k < Q1D; ++k) {
    V[k] = 0.0;
    if (tidx < Q1D && tidy < Q1D)
      for (int i = 0; i < P1D; ++i)
        V[k] += B[i + k*P1D] * U[i]; // Contract z direction
  }
  for (int k = Q1D; k < P1D; ++k)
    V[k] = 0.0;
}
530 
531 //------------------------------------------------------------------------------
532 // 3D transpose tensor contract z
533 //------------------------------------------------------------------------------
// Transpose contraction of the z index, entirely in registers (no shared
// memory or barriers; 'slice' is unused, kept for interface symmetry).
// B is accessed transposed as B[k + i*P1D]; entries P1D..Q1D-1 are zeroed
// (relevant only when Q1D > P1D).
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    V[k] = 0.0;
    if (tidx < Q1D && tidy < Q1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[k + i*P1D] * U[i]; // Contract z direction
  }
  for (int k = P1D; k < Q1D; ++k)
    V[k] = 0.0;
}
548 
549 //------------------------------------------------------------------------------
550 // 3D transpose tensor contract y
551 //------------------------------------------------------------------------------
// Transpose contraction of the y index plane-by-plane: B accessed transposed
// as B[tidy + i*P1D], output rows limited to tidy < P1D. Barriers outside the
// guard so all threads reach them in every iteration.
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < Q1D && tidy < P1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
    __syncthreads();
  }
}
567 
568 //------------------------------------------------------------------------------
569 // 3D transpose tensor contract x
570 //------------------------------------------------------------------------------
// Transpose contraction of the x index plane-by-plane: B accessed transposed
// as B[tidx + i*P1D], output columns limited to tidx < P1D. Barriers outside
// the guard so all threads reach them in every iteration.
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < P1D && tidy < P1D)
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
    __syncthreads();
  }
}
586 
587 //------------------------------------------------------------------------------
588 // 3D interpolate to quadrature points
589 //------------------------------------------------------------------------------
// Apply the 1D interpolation matrix s_B in all three tensor directions (or
// the transpose action). Each z-slice of the thread block handles one
// (element, component) pair; each thread keeps a full z-column in the T1D
// register arrays r_V/r_t, ping-ponging between them across contractions.
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *s_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  CeedScalar r_V[T1D]; // per-thread z-column buffer
  CeedScalar r_t[T1D]; // per-thread z-column buffer (ping-pong partner)

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;      // element owned within this block
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;           // component owned by this z-slice

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // Clear register buffers for this element.
    for (int i = 0; i < T1D; ++i) {
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      // Nodes -> quadrature points: contract x, y, z with s_B.
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX3d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
      ContractY3d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
      ContractZ3d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
      writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
    } else {
      // Quadrature points -> nodes: transpose-contract z, y, x.
      readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
    }
  }
}
626 
627 //------------------------------------------------------------------------------
628 // 3D derivatives at quadrature points
629 //------------------------------------------------------------------------------
// 3D gradient: forward pass writes d/dx, d/dy, d/dz to dim slots 0/1/2, each
// computed by substituting the gradient matrix s_G in one direction and the
// interpolation matrix s_B in the other two. Transpose pass reads each dim
// slot, applies the transposed contractions, and accumulates the three
// contributions into r_V via add(). Each z-slice of the block handles one
// (element, component) pair; z-columns live in the T1D register arrays.
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *s_B, const CeedScalar *s_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // Use P1D for one of these
  CeedScalar r_U[T1D]; // input z-column (reused per dim in the forward pass)
  CeedScalar r_V[T1D]; // output / accumulator z-column
  CeedScalar r_t[T1D]; // scratch z-column between contractions

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;      // element owned within this block
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;           // component owned by this z-slice
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // Clear register buffers for this element.
    for (int i = 0; i < T1D; ++i) {
      r_U[i] = 0.0;
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // d/dx: G in x, B in y and z.
      ContractX3d(slice, tidx, tidy, tidz, r_U, s_G, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
      dim = 0;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // d/dy: G in y, B in x and z (r_U still holds the DoFs).
      ContractX3d(slice, tidx, tidy, tidz, r_U, s_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, s_G, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
      dim = 1;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // d/dz: G in z, B in x and y.
      ContractX3d(slice, tidx, tidy, tidz, r_U, s_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, s_G, r_V);
      dim = 2;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      // Transpose action of d/dx into r_V.
      dim = 0;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, s_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, s_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, s_G, r_V);
      // Transpose action of d/dy, accumulated onto r_V.
      dim = 1;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, s_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, s_G, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, s_B, r_t);
      add(r_V, r_t);
      // Transpose action of d/dz, accumulated onto r_V.
      dim = 2;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, s_G, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, s_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, s_B, r_t);
      add(r_V, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
694 
695 //------------------------------------------------------------------------------
696 // 3D quadrature weights
697 //------------------------------------------------------------------------------
698 __device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
699                          CeedScalar *w) {
700   const int i = threadIdx.x;
701   const int j = threadIdx.y;
702   const int k = threadIdx.z;
703   const CeedScalar weight = qweight1d[i]*qweight1d[j]*qweight1d[k];
704   for (int e = blockIdx.x; e < nelem; e += gridDim.x) {
705     const int ind = e*Q1D*Q1D*Q1D + i + j*Q1D + k*Q1D*Q1D;
706     w[ind] = weight;
707   }
708 }
709 
710 //------------------------------------------------------------------------------
711 // Basis kernels
712 //------------------------------------------------------------------------------
713 
714 //------------------------------------------------------------------------------
715 // Interp kernel by dim
716 //------------------------------------------------------------------------------
717 extern "C" __launch_bounds__(INTERP_BLKSIZE) __global__ void interp(
718                                   const CeedInt nelem, const int transpose,
719                                   CeedScalar *d_interp1d,
720                                   const CeedScalar *__restrict__ d_U,
721                                   CeedScalar *__restrict__ d_V) {
722 
723   HIP_DYNAMIC_SHARED( CeedScalar, slice)
724   // load interp1d into shared memory
725   __shared__ CeedScalar s_B[P1D*Q1D];
726   loadMatrix(d_interp1d, s_B);
727   __syncthreads();
728 
729   if (BASIS_DIM == 1) {
730     interp1d(nelem, transpose, s_B, d_U, d_V, slice);
731   } else if (BASIS_DIM == 2) {
732     interp2d(nelem, transpose, s_B, d_U, d_V, slice);
733   } else if (BASIS_DIM == 3) {
734     interp3d(nelem, transpose, s_B, d_U, d_V, slice);
735   }
736 }
737 
738 //------------------------------------------------------------------------------
739 // Grad kernel by dim
740 //------------------------------------------------------------------------------
741 extern "C" __launch_bounds__(GRAD_BLKSIZE) __global__ void grad(const CeedInt nelem,
742                                 const int transpose,
743                                 CeedScalar *d_interp1d, CeedScalar *d_grad1d,
744                                 const CeedScalar *__restrict__ d_U,
745                                 CeedScalar *__restrict__ d_V) {
746   HIP_DYNAMIC_SHARED( CeedScalar, slice)
747   // load interp1d and grad1d into shared memory
748   __shared__ CeedScalar s_B[P1D*Q1D];
749   loadMatrix(d_interp1d, s_B);
750   __shared__ CeedScalar s_G[P1D*Q1D];
751   loadMatrix(d_grad1d, s_G);
752   __syncthreads();
753 
754   if (BASIS_DIM == 1) {
755     grad1d(nelem, transpose, s_B, s_G, d_U, d_V, slice);
756   } else if (BASIS_DIM == 2) {
757     grad2d(nelem, transpose, s_B, s_G, d_U, d_V, slice);
758   } else if (BASIS_DIM == 3) {
759     grad3d(nelem, transpose, s_B, s_G, d_U, d_V, slice);
760   }
761 }
762 
763 //------------------------------------------------------------------------------
764 // Weight kernels by dim
765 //------------------------------------------------------------------------------
766 extern "C" __launch_bounds__(WEIGHT_BLKSIZE) __global__ void weight(const CeedInt nelem,
767                                   const CeedScalar *__restrict__ qweight1d,
768                                   CeedScalar *__restrict__ v) {
769   if (BASIS_DIM == 1) {
770     weight1d(nelem, qweight1d, v);
771   } else if (BASIS_DIM == 2) {
772     weight2d(nelem, qweight1d, v);
773   } else if (BASIS_DIM == 3) {
774     weight3d(nelem, qweight1d, v);
775   }
776 }
777 
778 );
779 // *INDENT-ON*
780 
781 //------------------------------------------------------------------------------
782 // Compute a block size based on required minimum threads
783 //------------------------------------------------------------------------------
784 static CeedInt ComputeBlockSizeFromRequirement(const CeedInt required) {
785   CeedInt maxSize = 1024;    // Max total threads per block
786   CeedInt currentSize = 64;  // Start with one group
787 
788   while(currentSize < maxSize) {
789     if (currentSize > required)
790       break;
791     else
792       currentSize = currentSize * 2;
793   }
794   return currentSize;
795 }
796 
797 //------------------------------------------------------------------------------
798 // Compute required thread block sizes for basis kernels given P, Q, dim, and
799 // ncomp
800 //------------------------------------------------------------------------------
801 static int ComputeBasisThreadBlockSizes(const CeedInt dim, const CeedInt P1d,
802                                         const CeedInt Q1d,
803                                         const CeedInt ncomp, CeedInt *blksizes) {
804 
805   // Note that this will use the same block sizes for all dimensions when compiling,
806   // but as each basis object is defined for a particular dimension, we will never
807   // call any kernels except the ones for the dimension for which we have computed the
808   // block sizes.
809   const CeedInt thread1d = CeedIntMax(P1d, Q1d);
810   switch (dim) {
811   case 1: {
812     // Interp kernels:
813     blksizes[0] = 256;
814 
815     // Grad kernels:
816     blksizes[1] = 256;
817 
818     // Weight kernels:
819     blksizes[2] = 256;
820 
821   } break;
822   case 2: {
823     // Interp kernels:
824     CeedInt required = thread1d * thread1d * ncomp;
825     blksizes[0]  = ComputeBlockSizeFromRequirement(required);
826 
827     // Grad kernels: currently use same required minimum threads
828     blksizes[1]  = ComputeBlockSizeFromRequirement(required);
829 
830     // Weight kernels:
831     required = CeedIntMax(64, Q1d * Q1d);
832     blksizes[2]  = ComputeBlockSizeFromRequirement(required);
833 
834   } break;
835   case 3: {
836     // Interp kernels:
837     CeedInt required = thread1d * thread1d * ncomp;
838     blksizes[0]  = ComputeBlockSizeFromRequirement(required);
839 
840     // Grad kernels: currently use same required minimum threads
841     blksizes[1]  = ComputeBlockSizeFromRequirement(required);
842 
843     // Weight kernels:
844     required = Q1d * Q1d * Q1d;
845     blksizes[2]  = ComputeBlockSizeFromRequirement(required);
846   }
847   }
848 
849   return CEED_ERROR_SUCCESS;
850 }
851 
852 //------------------------------------------------------------------------------
853 // Apply basis
854 //------------------------------------------------------------------------------
855 int CeedBasisApplyTensor_Hip_shared(CeedBasis basis, const CeedInt nelem,
856                                     CeedTransposeMode tmode,
857                                     CeedEvalMode emode, CeedVector u,
858                                     CeedVector v) {
859   int ierr;
860   Ceed ceed;
861   ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
862   Ceed_Hip *ceed_Hip;
863   CeedGetData(ceed, &ceed_Hip); CeedChkBackend(ierr);
864   CeedBasis_Hip_shared *data;
865   CeedBasisGetData(basis, &data); CeedChkBackend(ierr);
866   const CeedInt transpose = tmode == CEED_TRANSPOSE;
867   CeedInt dim, ncomp;
868   ierr = CeedBasisGetDimension(basis, &dim); CeedChkBackend(ierr);
869   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChkBackend(ierr);
870 
871   // Read vectors
872   const CeedScalar *d_u;
873   CeedScalar *d_v;
874   if (emode != CEED_EVAL_WEIGHT) {
875     ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChkBackend(ierr);
876   }
877   ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChkBackend(ierr);
878 
879   // Clear v for transpose mode
880   if (tmode == CEED_TRANSPOSE) {
881     CeedInt length;
882     ierr = CeedVectorGetLength(v, &length); CeedChkBackend(ierr);
883     ierr = hipMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChkBackend(ierr);
884   }
885 
886   // Apply basis operation
887   switch (emode) {
888   case CEED_EVAL_INTERP: {
889     CeedInt P1d, Q1d;
890     CeedInt blksize = data->blksizes[0];
891     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChkBackend(ierr);
892     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr);
893     CeedInt thread1d = CeedIntMax(Q1d, P1d);
894     void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->d_interp1d,
895                           &d_u, &d_v
896                          };
897     if (dim == 1) {
898       CeedInt elemsPerBlock = 64*thread1d > 256? 256/thread1d : 64;
899       elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1;
900       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
901                                              ? 1 : 0 );
902       CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
903       ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, 1,
904                                        elemsPerBlock, sharedMem,
905                                        interpargs); CeedChkBackend(ierr);
906     } else if (dim == 2) {
907       // Check if required threads is small enough to do multiple elems
908       const CeedInt elemsPerBlock = CeedIntMax(blksize/(thread1d*thread1d*ncomp), 1);
909       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
910                                              ? 1 : 0 );
911       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
912       ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d,
913                                        ncomp*elemsPerBlock, sharedMem,
914                                        interpargs); CeedChkBackend(ierr);
915     } else if (dim == 3) {
916       CeedInt elemsPerBlock = 1;
917       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
918                                              ? 1 : 0 );
919       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
920       ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d,
921                                        ncomp*elemsPerBlock, sharedMem,
922                                        interpargs); CeedChkBackend(ierr);
923     }
924   } break;
925   case CEED_EVAL_GRAD: {
926     CeedInt P1d, Q1d;
927     CeedInt blksize = data->blksizes[1];
928     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChkBackend(ierr);
929     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr);
930     CeedInt thread1d = CeedIntMax(Q1d, P1d);
931     void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->d_interp1d,
932                         &data->d_grad1d, &d_u, &d_v
933                        };
934     if (dim == 1) {
935       CeedInt elemsPerBlock = 64*thread1d > 256? 256/thread1d : 64;
936       elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1;
937       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
938                                              ? 1 : 0 );
939       CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
940       ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, 1,
941                                        elemsPerBlock, sharedMem, gradargs);
942       CeedChkBackend(ierr);
943     } else if (dim == 2) {
944       // Check if required threads is small enough to do multiple elems
945       const CeedInt elemsPerBlock = CeedIntMax(blksize/(thread1d*thread1d*ncomp), 1);
946       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
947                                              ? 1 : 0 );
948       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
949       ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d,
950                                        ncomp*elemsPerBlock, sharedMem,
951                                        gradargs); CeedChkBackend(ierr);
952     } else if (dim == 3) {
953       CeedInt elemsPerBlock = 1;
954       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
955                                              ? 1 : 0 );
956       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
957       ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d,
958                                        ncomp*elemsPerBlock, sharedMem,
959                                        gradargs); CeedChkBackend(ierr);
960     }
961   } break;
962   case CEED_EVAL_WEIGHT: {
963     CeedInt Q1d;
964     CeedInt blksize = data->blksizes[2];
965     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr);
966     void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
967     if (dim == 1) {
968       const CeedInt optElems = blksize/Q1d;
969       const CeedInt elemsPerBlock = optElems>0?optElems:1;
970       const CeedInt gridsize = nelem/elemsPerBlock + ( (
971                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
972       ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d,
973                                  elemsPerBlock, 1, weightargs);
974       CeedChkBackend(ierr);
975     } else if (dim == 2) {
976       const CeedInt optElems = blksize/(Q1d*Q1d);
977       const CeedInt elemsPerBlock = optElems>0?optElems:1;
978       const CeedInt gridsize = nelem/elemsPerBlock + ( (
979                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
980       ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d,
981                                  elemsPerBlock, weightargs);
982       CeedChkBackend(ierr);
983     } else if (dim == 3) {
984       const CeedInt gridsize = nelem;
985       ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d, Q1d,
986                                  weightargs);
987       CeedChkBackend(ierr);
988     }
989   } break;
990   // LCOV_EXCL_START
991   // Evaluate the divergence to/from the quadrature points
992   case CEED_EVAL_DIV:
993     return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported");
994   // Evaluate the curl to/from the quadrature points
995   case CEED_EVAL_CURL:
996     return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported");
997   // Take no action, BasisApply should not have been called
998   case CEED_EVAL_NONE:
999     return CeedError(ceed, CEED_ERROR_BACKEND,
1000                      "CEED_EVAL_NONE does not make sense in this context");
1001     // LCOV_EXCL_STOP
1002   }
1003 
1004   // Restore vectors
1005   if (emode != CEED_EVAL_WEIGHT) {
1006     ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChkBackend(ierr);
1007   }
1008   ierr = CeedVectorRestoreArray(v, &d_v); CeedChkBackend(ierr);
1009   return CEED_ERROR_SUCCESS;
1010 }
1011 
1012 //------------------------------------------------------------------------------
1013 // Destroy basis
1014 //------------------------------------------------------------------------------
1015 static int CeedBasisDestroy_Hip_shared(CeedBasis basis) {
1016   int ierr;
1017   Ceed ceed;
1018   ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
1019 
1020   CeedBasis_Hip_shared *data;
1021   ierr = CeedBasisGetData(basis, &data); CeedChkBackend(ierr);
1022 
1023   CeedChk_Hip(ceed, hipModuleUnload(data->module));
1024 
1025   ierr = hipFree(data->d_qweight1d); CeedChk_Hip(ceed, ierr);
1026   ierr = hipFree(data->d_interp1d); CeedChk_Hip(ceed, ierr);
1027   ierr = hipFree(data->d_grad1d); CeedChk_Hip(ceed, ierr);
1028   ierr = hipFree(data->d_collograd1d); CeedChk_Hip(ceed, ierr);
1029 
1030   ierr = CeedFree(&data); CeedChkBackend(ierr);
1031 
1032   return CEED_ERROR_SUCCESS;
1033 }
1034 
1035 //------------------------------------------------------------------------------
1036 // Create tensor basis
1037 //------------------------------------------------------------------------------
1038 int CeedBasisCreateTensorH1_Hip_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
1039                                        const CeedScalar *interp1d,
1040                                        const CeedScalar *grad1d,
1041                                        const CeedScalar *qref1d,
1042                                        const CeedScalar *qweight1d,
1043                                        CeedBasis basis) {
1044   int ierr;
1045   Ceed ceed;
1046   ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
1047   CeedBasis_Hip_shared *data;
1048   ierr = CeedCalloc(1, &data); CeedChkBackend(ierr);
1049 
1050   // Copy basis data to GPU
1051   const CeedInt qBytes = Q1d * sizeof(CeedScalar);
1052   ierr = hipMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Hip(ceed, ierr);
1053   ierr = hipMemcpy(data->d_qweight1d, qweight1d, qBytes,
1054                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
1055 
1056   const CeedInt iBytes = qBytes * P1d;
1057   ierr = hipMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Hip(ceed, ierr);
1058   ierr = hipMemcpy(data->d_interp1d, interp1d, iBytes,
1059                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
1060 
1061   ierr = hipMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Hip(ceed, ierr);
1062   ierr = hipMemcpy(data->d_grad1d, grad1d, iBytes,
1063                    hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
1064 
1065   // Compute collocated gradient and copy to GPU
1066   data->d_collograd1d = NULL;
1067   if (dim == 3 && Q1d >= P1d) {
1068     CeedScalar *collograd1d;
1069     ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChkBackend(ierr);
1070     ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChkBackend(ierr);
1071     ierr = hipMalloc((void **)&data->d_collograd1d, qBytes * Q1d);
1072     CeedChk_Hip(ceed, ierr);
1073     ierr = hipMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d,
1074                      hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
1075     ierr = CeedFree(&collograd1d); CeedChkBackend(ierr);
1076   }
1077 
1078   // Set number of threads per block for basis kernels
1079   CeedInt ncomp;
1080   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChkBackend(ierr);
1081   ierr = ComputeBasisThreadBlockSizes(dim, P1d, Q1d, ncomp, data->blksizes);
1082   CeedChkBackend(ierr);
1083 
1084   // Compile basis kernels
1085   ierr = CeedCompileHip(ceed, kernelsShared, &data->module, 11,
1086                         "Q1D", Q1d,
1087                         "P1D", P1d,
1088                         "T1D", CeedIntMax(Q1d, P1d),
1089                         "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
1090                             Q1d : P1d, dim),
1091                         "BASIS_DIM", dim,
1092                         "BASIS_NCOMP", ncomp,
1093                         "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
1094                         "BASIS_NQPT", CeedIntPow(Q1d, dim),
1095                         "INTERP_BLKSIZE", data->blksizes[0],
1096                         "GRAD_BLKSIZE", data->blksizes[1],
1097                         "WEIGHT_BLKSIZE", data->blksizes[2]
1098                        ); CeedChkBackend(ierr);
1099   ierr = CeedGetKernelHip(ceed, data->module, "interp", &data->interp);
1100   CeedChkBackend(ierr);
1101   ierr = CeedGetKernelHip(ceed, data->module, "grad", &data->grad);
1102   CeedChkBackend(ierr);
1103   ierr = CeedGetKernelHip(ceed, data->module, "weight", &data->weight);
1104   CeedChkBackend(ierr);
1105 
1106   ierr = CeedBasisSetData(basis, data); CeedChkBackend(ierr);
1107 
1108   // Register backend functions
1109   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
1110                                 CeedBasisApplyTensor_Hip_shared);
1111   CeedChkBackend(ierr);
1112   ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
1113                                 CeedBasisDestroy_Hip_shared); CeedChkBackend(ierr);
1114   return CEED_ERROR_SUCCESS;
1115 }
1116 //------------------------------------------------------------------------------
1117