xref: /libCEED/backends/hip-shared/ceed-hip-shared-basis.c (revision 3d576824e8d990e1f48c6609089904bee9170514)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <ceed.h>
18 #include <ceed-backend.h>
19 #include <hip/hip_runtime.h>
20 #include <stddef.h>
21 #include "ceed-hip-shared.h"
22 #include "../hip/ceed-hip.h"
23 #include "../hip/ceed-hip-compile.h"
24 
25 //------------------------------------------------------------------------------
26 // Shared mem kernels
27 //------------------------------------------------------------------------------
28 // *INDENT-OFF*
29 static const char *kernelsShared = QUOTE(
30 
31 //------------------------------------------------------------------------------
32 // Sum input into output
33 //------------------------------------------------------------------------------
inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
  // Accumulate the first P1D entries of r_U into r_V, element by element
  for (int j = 0; j < P1D; ++j) {
    r_V[j] = r_V[j] + r_U[j];
  }
}
38 
39 //------------------------------------------------------------------------------
40 // 1D
41 //------------------------------------------------------------------------------
42 
43 //------------------------------------------------------------------------------
44 // Read DoFs
45 //------------------------------------------------------------------------------
inline __device__ void readDofs1d(const int elem, const int tidx,
                                  const int tidy, const int tidz,const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *slice) {
  // Stage this element/component's P1D nodal values into the slice workspace
  // row for tidz, zero-padding any remaining entries up to Q1D
  const int sliceOffset = tidz*T1D;
  const int dofOffset = elem*P1D + comp*P1D*nelem;
  for (int node = 0; node < P1D; node++)
    slice[sliceOffset + node] = d_U[dofOffset + node];
  for (int node = P1D; node < Q1D; node++)
    slice[sliceOffset + node] = 0.0;
}
55 
56 //------------------------------------------------------------------------------
57 // Write DoFs
58 //------------------------------------------------------------------------------
inline __device__ void writeDofs1d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  // Only threads mapped to one of the P1D nodes store a value
  if (tidx >= P1D)
    return;
  d_V[tidx + elem*P1D + comp*P1D*nelem] = r_V;
}
66 
67 //------------------------------------------------------------------------------
68 // Read quadrature point data
69 //------------------------------------------------------------------------------
inline __device__ void readQuads1d(const int elem, const int tidx,
                                   const int tidy, const int tidz, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *slice) {
  // Stage this element/component/dim's Q1D quadrature-point values into the
  // slice workspace row for tidz, zero-padding any remaining entries up to P1D
  const int sliceOffset = tidz*T1D;
  const int quadOffset = elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D;
  for (int q = 0; q < Q1D; q++)
    slice[sliceOffset + q] = d_U[quadOffset + q];
  for (int q = Q1D; q < P1D; q++)
    slice[sliceOffset + q] = 0.0;
}
80 
81 //------------------------------------------------------------------------------
82 // Write quadrature point data
83 //------------------------------------------------------------------------------
inline __device__ void writeQuads1d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  // Only threads mapped to one of the Q1D quadrature points store a value
  if (tidx >= Q1D)
    return;
  d_V[tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D] = r_V;
}
91 
92 //------------------------------------------------------------------------------
93 // 1D tensor contraction
94 //------------------------------------------------------------------------------
inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  // Contract the staged nodal values against B along x:
  //   V(tidx) = sum_i B(tidx, i) * slice(i)
  // The U argument is unused (values come from slice); kept for signature
  // parity with the 2D/3D contractions.
  // Guard on tidx < Q1D, matching ContractX2d/ContractX3d: the block may run
  // with more than Q1D threads in x, and B holds only P1D*Q1D entries, so an
  // unguarded thread with tidx >= Q1D would read past the end of B. Its V is
  // left 0 and is never stored (writeQuads1d only writes tidx < Q1D).
  V = 0.0;
  if (tidx < Q1D) {
    for (int i = 0; i < P1D; ++i)
      V += B[i + tidx*P1D] * slice[i + tidz*T1D]; // Contract x direction
  }
}
103 
104 //------------------------------------------------------------------------------
105 // 1D transpose tensor contraction
106 //------------------------------------------------------------------------------
inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  // Contract the staged quadrature values against B^T along x:
  //   V(tidx) = sum_i B(i, tidx) * slice(i)
  // The U argument is unused (values come from slice); kept for signature
  // parity with the 2D/3D contractions.
  // Guard on tidx < P1D, matching ContractTransposeX2d/3d: with more than
  // P1D threads in x, B[tidx + i*P1D] would index past the end of B (which
  // holds P1D*Q1D entries). The guarded-out V stays 0 and is never stored
  // (writeDofs1d only writes tidx < P1D).
  V = 0.0;
  if (tidx < P1D) {
    for (int i = 0; i < Q1D; ++i)
      V += B[tidx + i*P1D] * slice[i + tidz*T1D]; // Contract x direction
  }
}
114 
115 //------------------------------------------------------------------------------
116 // 1D interpolate to quadrature points
117 //------------------------------------------------------------------------------
inline __device__ void interp1d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  // Apply the 1D interpolation basis c_B for every element and component:
  //   transpose == 0: nodal values -> quadrature-point values
  //   transpose != 0: quadrature-point values -> nodal values
  // Each z-layer of the thread block handles one element; the grid-stride
  // loop lets a block process several elements.
  CeedScalar r_V;
  CeedScalar r_t; // passed as the unused U argument of the contractions

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;


  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for (int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) {
        // Stage nodal values in slice, contract with B, store at quad points
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
      } else {
        // Stage quad-point values in slice, contract with B^T, store at nodes
        readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
146 
147 //------------------------------------------------------------------------------
148 // 1D derivatives at quadrature points
149 //------------------------------------------------------------------------------
inline __device__ void grad1d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // Apply the 1D gradient basis c_G for every element and component:
  //   transpose == 0: nodal values -> derivative values at quadrature points
  //   transpose != 0: derivative values at quadrature points -> nodal values
  // Note: only c_G is contracted here; c_B is accepted for signature parity
  // with the 2D/3D gradient kernels, which mix B and G per direction.
  CeedScalar r_U; // passed as the unused U argument of the contractions
  CeedScalar r_V;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  int dim;

  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += gridDim.x*blockDim.z) {
    for(int comp = 0; comp < BASIS_NCOMP; comp++) {
      if (!transpose) {
        // Stage nodal values, contract with G, store at quad points (dim 0)
        readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
        ContractX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        dim = 0;
        writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      } else {
        // Stage quad-point values (dim 0), contract with G^T, store at nodes
        dim = 0;
        readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice);
        ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
        writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
      }
    }
  }
}
180 
181 //------------------------------------------------------------------------------
182 // 1D Quadrature weights
183 //------------------------------------------------------------------------------
__device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  // Broadcast the 1D quadrature weight for this thread's point into every
  // element's weight array; the y block dimension and the grid stride over
  // the elements.
  const int tid = threadIdx.x;
  const CeedScalar weight = qweight1d[tid];
  const CeedInt stride = gridDim.x*blockDim.y;
  CeedInt elem = blockIdx.x*blockDim.y + threadIdx.y;
  while (elem < nelem) {
    w[tid + elem*Q1D] = weight;
    elem += stride;
  }
}
194 
195 //------------------------------------------------------------------------------
196 // 2D
197 //------------------------------------------------------------------------------
198 
199 //------------------------------------------------------------------------------
200 // Read DoFs
201 //------------------------------------------------------------------------------
inline __device__ void readDofs2d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar &U) {
  // Load this thread's nodal value for (elem, comp); threads outside the
  // P1D x P1D node grid get zero so later contractions read defined values
  const bool active = tidx < P1D && tidy < P1D;
  if (active)
    U = d_U[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem];
  else
    U = 0.0;
}
209 
210 //------------------------------------------------------------------------------
211 // Write DoFs
212 //------------------------------------------------------------------------------
inline __device__ void writeDofs2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar &r_V,
                                   CeedScalar *d_V) {
  // Only threads inside the P1D x P1D node grid store a value
  if (tidx >= P1D || tidy >= P1D)
    return;
  d_V[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] = r_V;
}
220 
221 //------------------------------------------------------------------------------
222 // Read quadrature point data
223 //------------------------------------------------------------------------------
inline __device__ void readQuads2d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar &U ) {
  // Load this thread's quadrature-point value for (elem, comp, dim); threads
  // outside the Q1D x Q1D grid get zero
  const bool active = tidx < Q1D && tidy < Q1D;
  if (active)
    U = d_U[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
            dim*BASIS_NCOMP*nelem*Q1D*Q1D];
  else
    U = 0.0;
}
232 
233 //------------------------------------------------------------------------------
234 // Write quadrature point data
235 //------------------------------------------------------------------------------
inline __device__ void writeQuads2d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar &r_V, CeedScalar *d_V) {
  // Only threads inside the Q1D x Q1D quadrature grid store a value
  if (tidx >= Q1D || tidy >= Q1D)
    return;
  const int ind = tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
                  dim*BASIS_NCOMP*nelem*Q1D*Q1D;
  d_V[ind] = r_V;
}
244 
245 //------------------------------------------------------------------------------
246 // 2D tensor contraction x
247 //------------------------------------------------------------------------------
inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  // Contract U with B along x: V(tidx, tidy) = sum_i B(tidx, i) U(i, tidy).
  // Each thread publishes its U into the slice workspace; the block-wide
  // barriers separate the write phase from the read phase (and protect the
  // next caller's overwrite of slice). Both barriers sit outside the guarded
  // branch so every thread reaches them.
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidx < Q1D) // only quadrature-point threads produce output
    for (int i = 0; i < P1D; ++i)
      V += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
  __syncthreads();
}
260 
261 //------------------------------------------------------------------------------
262 // 2D tensor contraction y
263 //------------------------------------------------------------------------------
inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar &U, const CeedScalar *B,
                                   CeedScalar &V) {
  // Contract U with B along y: V(tidx, tidy) = sum_i B(tidy, i) U(tidx, i).
  // Barriers separate the slice write phase from the read phase; both sit
  // outside the guarded branch so every thread reaches them.
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidy < Q1D) // only quadrature-point threads produce output
    for (int i = 0; i < P1D; ++i)
      V += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
  __syncthreads();
}
276 
277 //------------------------------------------------------------------------------
278 // 2D transpose tensor contraction y
279 //------------------------------------------------------------------------------
inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  // Contract U with B^T along y: V(tidx, tidy) = sum_i B(i, tidy) U(tidx, i).
  // Barriers separate the slice write phase from the read phase; both sit
  // outside the guarded branch so every thread reaches them.
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidy < P1D) // only node threads produce output
    for (int i = 0; i < Q1D; ++i)
      V += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
  __syncthreads();
}
291 
292 //------------------------------------------------------------------------------
293 // 2D transpose tensor contraction x
294 //------------------------------------------------------------------------------
inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
    const int tidy, const int tidz,
    const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
  // Contract U with B^T along x: V(tidx, tidy) = sum_i B(i, tidx) U(i, tidy).
  // Barriers separate the slice write phase from the read phase; both sit
  // outside the guarded branch so every thread reaches them.
  slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
  __syncthreads();
  V = 0.0;
  if (tidx < P1D) // only node threads produce output
    for (int i = 0; i < Q1D; ++i)
      V += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
  __syncthreads();
}
306 
307 //------------------------------------------------------------------------------
308 // 2D interpolate to quadrature points
309 //------------------------------------------------------------------------------
inline __device__ void interp2d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  // Apply the 2D interpolation basis (tensor product of the 1D matrix c_B):
  //   transpose == 0: nodal values -> quadrature-point values (X then Y)
  //   transpose != 0: quadrature-point values -> nodal values (Y^T then X^T)
  // The z block dimension packs BASIS_NCOMP components per element; the
  // grid-stride loop lets a block process several elements.
  // (A redundant inner redeclaration of comp, identical to the function-scope
  // constant below, was removed.)
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;       // element index within block
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;            // component for this z-layer

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    r_V = 0.0;
    r_t = 0.0;
    if (!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
    } else {
      readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
343 
344 //------------------------------------------------------------------------------
345 // 2D derivatives at quadrature points
346 //------------------------------------------------------------------------------
inline __device__ void grad2d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V, CeedScalar *slice) {
  // Apply the 2D gradient basis (tensor products of the 1D interpolation
  // matrix c_B and differentiation matrix c_G):
  //   transpose == 0: nodal values -> the two derivative fields at quadrature
  //                   points, stored as dim 0 and dim 1 of d_V
  //   transpose != 0: both derivative fields are pulled back to nodes and
  //                   summed before being stored
  CeedScalar r_U;
  CeedScalar r_V;
  CeedScalar r_t;

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;       // element index within block
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;            // component for this z-layer
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    if (!transpose) {
      readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // dim 0: differentiate in x (G), interpolate in y (B)
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // dim 1: interpolate in x (B), differentiate in y (G)
      ContractX2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractY2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 1;
      writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      // Transpose of the dim 0 pass: B^T in y, then G^T in x
      dim = 0;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      // Transpose of the dim 1 pass: G^T in y, then B^T in x; accumulate
      dim = 1;
      readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      r_V += r_U;
      writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
389 
390 //------------------------------------------------------------------------------
391 // 2D quadrature weights
392 //------------------------------------------------------------------------------
__device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  // Tensor-product 2D quadrature weight at point (i, j), replicated across
  // all elements via a grid-stride loop over the z block dimension
  const int i = threadIdx.x;
  const int j = threadIdx.y;
  const CeedScalar weight = qweight1d[i]*qweight1d[j];
  const CeedInt stride = gridDim.x*blockDim.z;
  for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
       elem += stride)
    w[i + j*Q1D + elem*Q1D*Q1D] = weight;
}
404 
405 //------------------------------------------------------------------------------
406 // 3D
407 //------------------------------------------------------------------------------
408 
409 //------------------------------------------------------------------------------
410 // Read DoFs
411 //------------------------------------------------------------------------------
inline __device__ void readDofs3d(const int elem, const int tidx,
                                  const int tidy, const int comp,
                                  const int nelem, const CeedScalar *d_U,
                                  CeedScalar *r_U) {
  // Load the column of P1D nodal values at (tidx, tidy) into registers;
  // inactive threads and the tail entries up to Q1D are zero-filled
  const bool active = tidx < P1D && tidy < P1D;
  const int base = tidx + tidy*P1D + elem*P1D*P1D*P1D + comp*P1D*P1D*P1D*nelem;
  for (int k = 0; k < P1D; k++)
    r_U[k] = active ? d_U[base + k*P1D*P1D] : 0.0;
  for (int k = P1D; k < Q1D; k++)
    r_U[k] = 0.0;
}
423 
424 //------------------------------------------------------------------------------
425 // Write DoFs
426 //------------------------------------------------------------------------------
inline __device__ void writeDofs3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int nelem, const CeedScalar *r_V,
                                   CeedScalar *d_V) {
  // Store the column of P1D register values back at (tidx, tidy);
  // threads outside the P1D x P1D node grid write nothing
  if (tidx >= P1D || tidy >= P1D)
    return;
  const int base = tidx + tidy*P1D + elem*P1D*P1D*P1D + comp*P1D*P1D*P1D*nelem;
  for (int k = 0; k < P1D; k++)
    d_V[base + k*P1D*P1D] = r_V[k];
}
437 
438 //------------------------------------------------------------------------------
439 // Read quadrature point data
440 //------------------------------------------------------------------------------
inline __device__ void readQuads3d(const int elem, const int tidx,
                                   const int tidy, const int comp,
                                   const int dim, const int nelem,
                                   const CeedScalar *d_U, CeedScalar *r_U) {
  // Load the column of Q1D quadrature-point values at (tidx, tidy) into
  // registers; inactive threads and the tail entries up to P1D are zero-filled
  const bool active = tidx < Q1D && tidy < Q1D;
  const int base = tidx + tidy*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
                   dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
  for (int k = 0; k < Q1D; k++)
    r_U[k] = active ? d_U[base + k*Q1D*Q1D] : 0.0;
  for (int k = Q1D; k < P1D; k++)
    r_U[k] = 0.0;
}
452 
453 //------------------------------------------------------------------------------
454 // Write quadrature point data
455 //------------------------------------------------------------------------------
inline __device__ void writeQuads3d(const int elem, const int tidx,
                                    const int tidy, const int comp,
                                    const int dim, const int nelem,
                                    const CeedScalar *r_V, CeedScalar *d_V) {
  // Store the column of Q1D register values back at (tidx, tidy);
  // threads outside the Q1D x Q1D quadrature grid write nothing
  if (tidx >= Q1D || tidy >= Q1D)
    return;
  const int base = tidx + tidy*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
                   dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D;
  for (int k = 0; k < Q1D; k++)
    d_V[base + k*Q1D*Q1D] = r_V[k];
}
466 
467 //------------------------------------------------------------------------------
468 // 3D tensor contract x
469 //------------------------------------------------------------------------------
inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  // Contract the register column U with B along x, one z-level k at a time:
  //   V[k](tidx, tidy) = sum_i B(tidx, i) U[k](i, tidy)
  // Every thread runs all P1D iterations, so the block-wide barriers around
  // each level's slice traffic are reached uniformly.
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < Q1D && tidy < P1D) // only threads producing output contract
      for (int i = 0; i < P1D; ++i)
        V[k] += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
    __syncthreads();
  }
}
485 
486 //------------------------------------------------------------------------------
487 // 3D tensor contract y
488 //------------------------------------------------------------------------------
inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  // Contract the register column U with B along y, one z-level k at a time:
  //   V[k](tidx, tidy) = sum_i B(tidy, i) U[k](tidx, i)
  // Every thread runs all P1D iterations, so the block-wide barriers around
  // each level's slice traffic are reached uniformly.
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < Q1D && tidy < Q1D) // only threads producing output contract
      for (int i = 0; i < P1D; ++i)
        V[k] += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
    __syncthreads();
  }
}
504 
505 //------------------------------------------------------------------------------
506 // 3D tensor contract z
507 //------------------------------------------------------------------------------
inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
                                   const int tidy, const int tidz,
                                   const CeedScalar *U,
                                   const CeedScalar *B,
                                   CeedScalar *V) {
  // Contract the register column U with B along z: V[k] = sum_i B(k, i) U[i].
  // The z direction lives entirely in each thread's register column, so no
  // slice staging or synchronization is needed; slice is unused here.
  for (int k = 0; k < Q1D; ++k) {
    V[k] = 0.0;
    if (tidx < Q1D && tidy < Q1D) // only threads producing output contract
      for (int i = 0; i < P1D; ++i)
        V[k] += B[i + k*P1D] * U[i]; // Contract z direction
  }
  // Zero any tail entries (when P1D > Q1D) so later stages read defined values
  for (int k = Q1D; k < P1D; ++k)
    V[k] = 0.0;
}
522 
523 //------------------------------------------------------------------------------
524 // 3D transpose tensor contract z
525 //------------------------------------------------------------------------------
inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  // Contract the register column U with B^T along z:
  //   V[k] = sum_i B(i, k) U[i]
  // Works entirely in each thread's register column; slice is unused here.
  for (int k = 0; k < P1D; ++k) {
    V[k] = 0.0;
    if (tidx < Q1D && tidy < Q1D) // only threads producing output contract
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[k + i*P1D] * U[i]; // Contract z direction
  }
  // Zero any tail entries (when Q1D > P1D) so later stages read defined values
  for (int k = P1D; k < Q1D; ++k)
    V[k] = 0.0;
}
540 
541 //------------------------------------------------------------------------------
542 // 3D transpose tensor contract y
543 //------------------------------------------------------------------------------
inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  // Contract the register column U with B^T along y, one z-level k at a time:
  //   V[k](tidx, tidy) = sum_i B(i, tidy) U[k](tidx, i)
  // Every thread runs all P1D iterations, so the block-wide barriers around
  // each level's slice traffic are reached uniformly.
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < Q1D && tidy < P1D) // only threads producing output contract
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
    __syncthreads();
  }
}
559 
560 //------------------------------------------------------------------------------
561 // 3D transpose tensor contract x
562 //------------------------------------------------------------------------------
inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
                                            const int tidy, const int tidz,
                                            const CeedScalar *U,
                                            const CeedScalar *B,
                                            CeedScalar *V) {
  // Contract the register column U with B^T along x, one z-level k at a time:
  //   V[k](tidx, tidy) = sum_i B(i, tidx) U[k](i, tidy)
  // Every thread runs all P1D iterations, so the block-wide barriers around
  // each level's slice traffic are reached uniformly.
  for (int k = 0; k < P1D; ++k) {
    slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
    __syncthreads();
    V[k] = 0.0;
    if (tidx < P1D && tidy < P1D) // only threads producing output contract
      for (int i = 0; i < Q1D; ++i)
        V[k] += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
    __syncthreads();
  }
}
578 
579 //------------------------------------------------------------------------------
580 // 3D interpolate to quadrature points
581 //------------------------------------------------------------------------------
inline __device__ void interp3d(const CeedInt nelem, const int transpose,
                                const CeedScalar *c_B,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V,
                                CeedScalar *slice) {
  // Apply the 3D interpolation basis (tensor product of the 1D matrix c_B):
  //   transpose == 0: nodal values -> quadrature-point values (X, Y, then Z)
  //   transpose != 0: quadrature-point values -> nodal values (Z^T, Y^T, X^T)
  // Each thread carries a full z-column of values in registers; the z block
  // dimension packs BASIS_NCOMP components per element.
  CeedScalar r_V[T1D];
  CeedScalar r_t[T1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;       // element index within block
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;            // component for this z-layer

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // Clear both register columns before reusing them for this element
    for (int i = 0; i < T1D; ++i) {
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
      ContractX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
    } else {
      readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
    }
  }
}
618 
619 //------------------------------------------------------------------------------
620 // 3D derivatives at quadrature points
621 //------------------------------------------------------------------------------
inline __device__ void grad3d(const CeedInt nelem, const int transpose,
                              const CeedScalar *c_B, const CeedScalar *c_G,
                              const CeedScalar *__restrict__ d_U,
                              CeedScalar *__restrict__ d_V,
                              CeedScalar *slice) {
  // Apply the 3D gradient basis (tensor products of the 1D interpolation
  // matrix c_B and differentiation matrix c_G):
  //   transpose == 0: nodal values -> the three derivative fields at
  //                   quadrature points, stored as dims 0..2 of d_V
  //   transpose != 0: all three derivative fields are pulled back to nodes
  //                   and summed before being stored
  // NOTE(review): all three register columns are sized T1D; one of them
  // likely only needs P1D entries — confirm before shrinking.
  CeedScalar r_U[T1D];
  CeedScalar r_V[T1D];
  CeedScalar r_t[T1D];

  const int tidx = threadIdx.x;
  const int tidy = threadIdx.y;
  const int tidz = threadIdx.z;
  const int blockElem = tidz/BASIS_NCOMP;       // element index within block
  const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
  const int comp = tidz%BASIS_NCOMP;            // component for this z-layer
  int dim;

  for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
       elem += gridDim.x*elemsPerBlock) {
    // Clear the register columns before reusing them for this element
    for (int i = 0; i < T1D; ++i) {
      r_U[i] = 0.0;
      r_V[i] = 0.0;
      r_t[i] = 0.0;
    }
    if (!transpose) {
      readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
      // dim 0: G in x, B in y, B in z
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 0;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // dim 1: B in x, G in y, B in z
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_G, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_B, r_V);
      dim = 1;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
      // dim 2: B in x, B in y, G in z
      ContractX3d(slice, tidx, tidy, tidz, r_U, c_B, r_V);
      ContractY3d(slice, tidx, tidy, tidz, r_V, c_B, r_t);
      ContractZ3d(slice, tidx, tidy, tidz, r_t, c_G, r_V);
      dim = 2;
      writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
    } else {
      // Transpose of the dim 0 pass: B^T in z, B^T in y, G^T in x
      dim = 0;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_G, r_V);
      // Transpose of the dim 1 pass, accumulated into r_V via add()
      dim = 1;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_G, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      // Transpose of the dim 2 pass, accumulated into r_V via add()
      dim = 2;
      readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
      ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, c_G, r_t);
      ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, c_B, r_U);
      ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, c_B, r_t);
      add(r_V, r_t);
      writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
    }
  }
}
686 
687 //------------------------------------------------------------------------------
688 // 3D quadrature weights
689 //------------------------------------------------------------------------------
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  // One thread per 3D quadrature point: the (x, y, z) thread indices pick the
  // three 1D weights whose tensor product gives this point's weight
  const int i = threadIdx.x;
  const int j = threadIdx.y;
  const int k = threadIdx.z;
  const CeedScalar weight = qweight1d[i]*qweight1d[j]*qweight1d[k];
  // Stride over elements by gridDim.x; the weight is the same for every
  // element, only the output offset changes (x fastest, then y, then z)
  for (int e = blockIdx.x; e < nelem; e += gridDim.x) {
    const int ind = e*Q1D*Q1D*Q1D + i + j*Q1D + k*Q1D*Q1D;
    w[ind] = weight;
  }
}
701 
702 
703 //------------------------------------------------------------------------------
704 // Basis kernels
705 //------------------------------------------------------------------------------
706 
707 //------------------------------------------------------------------------------
708 // Interp kernel by dim
709 //------------------------------------------------------------------------------
extern "C" __launch_bounds__(INTERP_BLKSIZE) __global__ void interp(
                                  const CeedInt nelem, const int transpose,
                                  const CeedScalar *c_B,
                                  const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  // Dynamic shared memory used as the slice buffer by the per-dim helpers;
  // its size is supplied by the host at launch time
  HIP_DYNAMIC_SHARED( double, slice)
  // BASIS_DIM is defined at JIT-compile time, so only one branch survives
  // dead-code elimination; dispatch to the matching dimension kernel
  if (BASIS_DIM == 1) {
    interp1d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM == 2) {
    interp2d(nelem, transpose, c_B, d_U, d_V, slice);
  } else if (BASIS_DIM == 3) {
    interp3d(nelem, transpose, c_B, d_U, d_V, slice);
  }
}
724 
725 //------------------------------------------------------------------------------
726 // Grad kernel by dim
727 //------------------------------------------------------------------------------
extern "C" __launch_bounds__(GRAD_BLKSIZE) __global__ void grad(const CeedInt nelem,
                                const int transpose,
                                const CeedScalar *c_B, const CeedScalar *c_G,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V) {
  // Dynamic shared memory used as the slice buffer by the per-dim helpers;
  // its size is supplied by the host at launch time
  HIP_DYNAMIC_SHARED( double, slice)
  // BASIS_DIM is defined at JIT-compile time, so only one branch survives
  // dead-code elimination; dispatch to the matching dimension kernel
  if (BASIS_DIM == 1) {
    grad1d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM == 2) {
    grad2d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  } else if (BASIS_DIM == 3) {
    grad3d(nelem, transpose, c_B, c_G, d_U, d_V, slice);
  }
}
742 
743 //------------------------------------------------------------------------------
744 // Weight kernels by dim
745 //------------------------------------------------------------------------------
extern "C" __launch_bounds__(WEIGHT_BLKSIZE) __global__ void weight(const CeedInt nelem,
                                  const CeedScalar *__restrict__ qweight1d,
                                  CeedScalar *__restrict__ v) {
  // No shared memory needed: each thread computes its weight independently.
  // BASIS_DIM is defined at JIT-compile time, so only one branch survives
  // dead-code elimination; dispatch to the matching dimension kernel
  if (BASIS_DIM == 1) {
    weight1d(nelem, qweight1d, v);
  } else if (BASIS_DIM == 2) {
    weight2d(nelem, qweight1d, v);
  } else if (BASIS_DIM == 3) {
    weight3d(nelem, qweight1d, v);
  }
}
757 
758 );
759 // *INDENT-ON*
760 
761 //------------------------------------------------------------------------------
762 // Compute a block size based on required minimum threads
763 //------------------------------------------------------------------------------
764 static CeedInt ComputeBlockSizeFromRequirement(const CeedInt required) {
765   CeedInt maxSize = 1024;    // Max total threads per block
766   CeedInt currentSize = 64;  // Start with one group
767 
768   while(currentSize < maxSize) {
769     if (currentSize > required)
770       break;
771     else
772       currentSize = currentSize * 2;
773   }
774   return currentSize;
775 }
776 
777 //------------------------------------------------------------------------------
778 // Compute required thread block sizes for basis kernels given P, Q, dim, and
779 // ncomp
780 //------------------------------------------------------------------------------
781 static int ComputeBasisThreadBlockSizes(const CeedInt dim, const CeedInt P1d,
782                                         const CeedInt Q1d,
783                                         const CeedInt ncomp, CeedInt *blksizes) {
784 
785   // Note that this will use the same block sizes for all dimensions when compiling,
786   // but as each basis object is defined for a particular dimension, we will never
787   // call any kernels except the ones for the dimension for which we have computed the
788   // block sizes.
789   const CeedInt thread1d = CeedIntMax(P1d, Q1d);
790   switch (dim) {
791   case 1: {
792     // Interp kernels:
793     blksizes[0] = 256;
794 
795     // Grad kernels:
796     blksizes[1] = 256;
797 
798     // Weight kernels:
799     blksizes[2] = 256;
800 
801   } break;
802   case 2: {
803     // Interp kernels:
804     CeedInt required = thread1d * thread1d * ncomp;
805     blksizes[0]  = ComputeBlockSizeFromRequirement(required);
806 
807     // Grad kernels: currently use same required minimum threads
808     blksizes[1]  = ComputeBlockSizeFromRequirement(required);
809 
810     // Weight kernels:
811     required = CeedIntMax(64, Q1d * Q1d);
812     blksizes[2]  = ComputeBlockSizeFromRequirement(required);
813 
814   } break;
815   case 3: {
816     // Interp kernels:
817     CeedInt required = thread1d * thread1d * ncomp;
818     blksizes[0]  = ComputeBlockSizeFromRequirement(required);
819 
820     // Grad kernels: currently use same required minimum threads
821     blksizes[1]  = ComputeBlockSizeFromRequirement(required);
822 
823     // Weight kernels:
824     required = Q1d * Q1d * Q1d;
825     blksizes[2]  = ComputeBlockSizeFromRequirement(required);
826   }
827   }
828 
829   return 0;
830 }
831 
832 //------------------------------------------------------------------------------
833 // Apply basis
834 //------------------------------------------------------------------------------
835 int CeedBasisApplyTensor_Hip_shared(CeedBasis basis, const CeedInt nelem,
836                                     CeedTransposeMode tmode,
837                                     CeedEvalMode emode, CeedVector u,
838                                     CeedVector v) {
839   int ierr;
840   Ceed ceed;
841   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
842   Ceed_Hip_shared *ceed_Hip;
843   CeedGetData(ceed, &ceed_Hip); CeedChk(ierr);
844   CeedBasis_Hip_shared *data;
845   CeedBasisGetData(basis, &data); CeedChk(ierr);
846   const CeedInt transpose = tmode == CEED_TRANSPOSE;
847   CeedInt dim, ncomp;
848   ierr = CeedBasisGetDimension(basis, &dim); CeedChk(ierr);
849   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
850 
851   // Read vectors
852   const CeedScalar *d_u;
853   CeedScalar *d_v;
854   if (emode != CEED_EVAL_WEIGHT) {
855     ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChk(ierr);
856   }
857   ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChk(ierr);
858 
859   // Clear v for transpose mode
860   if (tmode == CEED_TRANSPOSE) {
861     CeedInt length;
862     ierr = CeedVectorGetLength(v, &length); CeedChk(ierr);
863     ierr = hipMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChk(ierr);
864   }
865 
866   // Apply basis operation
867   switch (emode) {
868   case CEED_EVAL_INTERP: {
869     CeedInt P1d, Q1d;
870     CeedInt blksize = data->blksizes[0];
871     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
872     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
873     CeedInt thread1d = CeedIntMax(Q1d, P1d);
874     ierr = CeedHipInitInterp(data->d_interp1d, P1d, Q1d, &data->c_B);
875     CeedChk(ierr);
876     void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
877                           &d_u, &d_v
878                          };
879     if (dim == 1) {
880       CeedInt elemsPerBlock = 64*thread1d > 256? 256/thread1d : 64;
881       elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1;
882       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
883                                              ? 1 : 0 );
884       CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
885       ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, 1,
886                                        elemsPerBlock, sharedMem,
887                                        interpargs); CeedChk(ierr);
888     } else if (dim == 2) {
889       // Check if required threads is small enough to do multiple elems
890       const CeedInt elemsPerBlock = CeedIntMax(blksize/(thread1d*thread1d*ncomp), 1);
891       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
892                                              ? 1 : 0 );
893       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
894       ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d,
895                                        ncomp*elemsPerBlock, sharedMem,
896                                        interpargs); CeedChk(ierr);
897     } else if (dim == 3) {
898       CeedInt elemsPerBlock = 1;
899       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
900                                              ? 1 : 0 );
901       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
902       ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d,
903                                        ncomp*elemsPerBlock, sharedMem,
904                                        interpargs); CeedChk(ierr);
905     }
906   } break;
907   case CEED_EVAL_GRAD: {
908     CeedInt P1d, Q1d;
909     CeedInt blksize = data->blksizes[1];
910     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChk(ierr);
911     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
912     CeedInt thread1d = CeedIntMax(Q1d, P1d);
913     ierr = CeedHipInitInterpGrad(data->d_interp1d, data->d_grad1d, P1d,
914                                  Q1d, &data->c_B, &data->c_G);
915     CeedChk(ierr);
916     void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->c_B,
917                         &data->c_G, &d_u, &d_v
918                        };
919     if (dim == 1) {
920       CeedInt elemsPerBlock = 64*thread1d > 256? 256/thread1d : 64;
921       elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1;
922       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
923                                              ? 1 : 0 );
924       CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
925       ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, 1,
926                                        elemsPerBlock, sharedMem, gradargs);
927       CeedChk(ierr);
928     } else if (dim == 2) {
929       // Check if required threads is small enough to do multiple elems
930       const CeedInt elemsPerBlock = CeedIntMax(blksize/(thread1d*thread1d*ncomp), 1);
931       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
932                                              ? 1 : 0 );
933       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
934       ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d,
935                                        ncomp*elemsPerBlock, sharedMem,
936                                        gradargs); CeedChk(ierr);
937     } else if (dim == 3) {
938       CeedInt elemsPerBlock = 1;
939       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
940                                              ? 1 : 0 );
941       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
942       ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d,
943                                        ncomp*elemsPerBlock, sharedMem,
944                                        gradargs); CeedChk(ierr);
945     }
946   } break;
947   case CEED_EVAL_WEIGHT: {
948     CeedInt Q1d;
949     CeedInt blksize = data->blksizes[2];
950     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChk(ierr);
951     void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
952     if (dim == 1) {
953       const CeedInt optElems = blksize/Q1d;
954       const CeedInt elemsPerBlock = optElems>0?optElems:1;
955       const CeedInt gridsize = nelem/elemsPerBlock + ( (
956                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
957       ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d,
958                                  elemsPerBlock, 1, weightargs);
959       CeedChk(ierr);
960     } else if (dim == 2) {
961       const CeedInt optElems = blksize/(Q1d*Q1d);
962       const CeedInt elemsPerBlock = optElems>0?optElems:1;
963       const CeedInt gridsize = nelem/elemsPerBlock + ( (
964                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
965       ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d,
966                                  elemsPerBlock, weightargs);
967       CeedChk(ierr);
968     } else if (dim == 3) {
969       const CeedInt gridsize = nelem;
970       ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d, Q1d,
971                                  weightargs);
972       CeedChk(ierr);
973     }
974   } break;
975   // LCOV_EXCL_START
976   // Evaluate the divergence to/from the quadrature points
977   case CEED_EVAL_DIV:
978     return CeedError(ceed, 1, "CEED_EVAL_DIV not supported");
979   // Evaluate the curl to/from the quadrature points
980   case CEED_EVAL_CURL:
981     return CeedError(ceed, 1, "CEED_EVAL_CURL not supported");
982   // Take no action, BasisApply should not have been called
983   case CEED_EVAL_NONE:
984     return CeedError(ceed, 1,
985                      "CEED_EVAL_NONE does not make sense in this context");
986     // LCOV_EXCL_STOP
987   }
988 
989   // Restore vectors
990   if (emode != CEED_EVAL_WEIGHT) {
991     ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChk(ierr);
992   }
993   ierr = CeedVectorRestoreArray(v, &d_v); CeedChk(ierr);
994   return 0;
995 }
996 
997 //------------------------------------------------------------------------------
998 // Destroy basis
999 //------------------------------------------------------------------------------
1000 static int CeedBasisDestroy_Hip_shared(CeedBasis basis) {
1001   int ierr;
1002   Ceed ceed;
1003   ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
1004 
1005   CeedBasis_Hip_shared *data;
1006   ierr = CeedBasisGetData(basis, &data); CeedChk(ierr);
1007 
1008   CeedChk_Hip(ceed, hipModuleUnload(data->module));
1009 
1010   ierr = hipFree(data->d_qweight1d); CeedChk_Hip(ceed, ierr);
1011   ierr = hipFree(data->d_interp1d); CeedChk_Hip(ceed, ierr);
1012   ierr = hipFree(data->d_grad1d); CeedChk_Hip(ceed, ierr);
1013   ierr = hipFree(data->d_collograd1d); CeedChk_Hip(ceed, ierr);
1014 
1015   ierr = CeedFree(&data); CeedChk(ierr);
1016 
1017   return 0;
1018 }
1019 
1020 //------------------------------------------------------------------------------
1021 // Create tensor basis
1022 //------------------------------------------------------------------------------
// Create a tensor-product H1 basis for the HIP shared-memory backend:
// uploads the 1D quadrature weights, interp, and grad matrices to the
// device, JIT-compiles the basis kernels with the sizes baked in as macros,
// and registers the Apply/Destroy backend functions.
// Note: qref1d is unused here; only the weights and matrices go to the GPU.
int CeedBasisCreateTensorH1_Hip_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
                                       const CeedScalar *interp1d,
                                       const CeedScalar *grad1d,
                                       const CeedScalar *qref1d,
                                       const CeedScalar *qweight1d,
                                       CeedBasis basis) {
  int ierr;
  Ceed ceed;
  ierr = CeedBasisGetCeed(basis, &ceed); CeedChk(ierr);
  CeedBasis_Hip_shared *data;
  ierr = CeedCalloc(1, &data); CeedChk(ierr);

  // Copy basis data to GPU
  const CeedInt qBytes = Q1d * sizeof(CeedScalar);
  ierr = hipMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Hip(ceed, ierr);
  ierr = hipMemcpy(data->d_qweight1d, qweight1d, qBytes,
                   hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);

  // interp1d and grad1d are both Q1d x P1d matrices
  const CeedInt iBytes = qBytes * P1d;
  ierr = hipMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Hip(ceed, ierr);
  ierr = hipMemcpy(data->d_interp1d, interp1d, iBytes,
                   hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);

  ierr = hipMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Hip(ceed, ierr);
  ierr = hipMemcpy(data->d_grad1d, grad1d, iBytes,
                   hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);

  // Compute collocated gradient and copy to GPU
  // (only built for 3D bases with Q1d >= P1d; otherwise left NULL)
  data->d_collograd1d = NULL;
  if (dim == 3 && Q1d >= P1d) {
    CeedScalar *collograd1d;
    ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChk(ierr);
    ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChk(ierr);
    ierr = hipMalloc((void **)&data->d_collograd1d, qBytes * Q1d);
    CeedChk_Hip(ceed, ierr);
    ierr = hipMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d,
                     hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
    ierr = CeedFree(&collograd1d); CeedChk(ierr);
  }

  // Set number of threads per block for basis kernels
  CeedInt ncomp;
  ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChk(ierr);
  ierr = ComputeBasisThreadBlockSizes(dim, P1d, Q1d, ncomp, data->blksizes);
  CeedChk(ierr);

  // Compile basis kernels, baking the sizes and computed block sizes into
  // the kernel source as preprocessor macros
  ierr = CeedCompileHip(ceed, kernelsShared, &data->module, 11,
                        "Q1D", Q1d,
                        "P1D", P1d,
                        "T1D", CeedIntMax(Q1d, P1d),
                        "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
                            Q1d : P1d, dim),
                        "BASIS_DIM", dim,
                        "BASIS_NCOMP", ncomp,
                        "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
                        "BASIS_NQPT", CeedIntPow(Q1d, dim),
                        "INTERP_BLKSIZE", data->blksizes[0],
                        "GRAD_BLKSIZE", data->blksizes[1],
                        "WEIGHT_BLKSIZE", data->blksizes[2]
                       ); CeedChk(ierr);
  ierr = CeedGetKernelHip(ceed, data->module, "interp", &data->interp);
  CeedChk(ierr);
  ierr = CeedGetKernelHip(ceed, data->module, "grad", &data->grad);
  CeedChk(ierr);
  ierr = CeedGetKernelHip(ceed, data->module, "weight", &data->weight);
  CeedChk(ierr);

  ierr = CeedBasisSetData(basis, data); CeedChk(ierr);

  // Register backend functions
  ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
                                CeedBasisApplyTensor_Hip_shared);
  CeedChk(ierr);
  ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
                                CeedBasisDestroy_Hip_shared); CeedChk(ierr);
  return 0;
}
1101 //------------------------------------------------------------------------------
1102