xref: /libCEED/backends/hip-shared/ceed-hip-shared-basis.c (revision e15f9bd09af0280c89b79924fa9af7dd2e3e30be)
1 // Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
2 // Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
3 // All Rights reserved. See files LICENSE and NOTICE for details.
4 //
5 // This file is part of CEED, a collection of benchmarks, miniapps, software
6 // libraries and APIs for efficient high-order finite element and spectral
7 // element discretizations for exascale applications. For more information and
8 // source code availability see http://github.com/ceed.
9 //
10 // The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
11 // a collaborative effort of two U.S. Department of Energy organizations (Office
12 // of Science and the National Nuclear Security Administration) responsible for
13 // the planning and preparation of a capable exascale ecosystem, including
14 // software, applications, hardware, advanced system engineering and early
15 // testbed platforms, in support of the nation's exascale computing imperative.
16 
17 #include <ceed.h>
18 #include <ceed-backend.h>
19 #include <hip/hip_runtime.h>
20 #include <stddef.h>
21 #include "ceed-hip-shared.h"
22 #include "../hip/ceed-hip.h"
23 #include "../hip/ceed-hip-compile.h"
24 
25 //------------------------------------------------------------------------------
26 // Shared mem kernels
27 //------------------------------------------------------------------------------
28 // *INDENT-OFF*
29 static const char *kernelsShared = QUOTE(
30 
31 //------------------------------------------------------------------------------
32 // Sum input into output
33 //------------------------------------------------------------------------------
34 inline __device__ void add(CeedScalar *r_V, const CeedScalar *r_U) {
35   for (int i = 0; i < P1D; i++)
36     r_V[i] += r_U[i];
37 }
38 
39 //------------------------------------------------------------------------------
40 // Load matrices for basis actions
41 //------------------------------------------------------------------------------
42 inline __device__ void loadMatrix(const CeedScalar* d_B, CeedScalar* B) {
43   CeedInt tid = threadIdx.x + threadIdx.y*blockDim.x + threadIdx.z*blockDim.y*blockDim.x;
44   for (CeedInt i = tid; i < P1D*Q1D; i += blockDim.x*blockDim.y*blockDim.z)
45     B[i] = d_B[i];
46 }
47 
48 //------------------------------------------------------------------------------
49 // 1D
50 //------------------------------------------------------------------------------
51 
52 //------------------------------------------------------------------------------
53 // Read DoFs
54 //------------------------------------------------------------------------------
55 inline __device__ void readDofs1d(const int elem, const int tidx,
56                                   const int tidy, const int tidz,const int comp,
57                                   const int nelem, const CeedScalar *d_U,
58                                   CeedScalar *slice) {
59   for (int i = 0; i < P1D; i++)
60     slice[i + tidz*T1D] = d_U[i + elem*P1D + comp*P1D*nelem];
61   for (int i = P1D; i < Q1D; i++)
62     slice[i + tidz*T1D] = 0.0;
63 }
64 
65 //------------------------------------------------------------------------------
66 // Write DoFs
67 //------------------------------------------------------------------------------
68 inline __device__ void writeDofs1d(const int elem, const int tidx,
69                                    const int tidy, const int comp,
70                                    const int nelem, const CeedScalar &r_V,
71                                    CeedScalar *d_V) {
72   if (tidx<P1D)
73     d_V[tidx + elem*P1D + comp*P1D*nelem] = r_V;
74 }
75 
76 //------------------------------------------------------------------------------
77 // Read quadrature point data
78 //------------------------------------------------------------------------------
79 inline __device__ void readQuads1d(const int elem, const int tidx,
80                                    const int tidy, const int tidz, const int comp,
81                                    const int dim, const int nelem,
82                                    const CeedScalar *d_U, CeedScalar *slice) {
83   for (int i = 0; i < Q1D; i++)
84     slice[i + tidz*T1D] = d_U[i + elem*Q1D + comp*Q1D*nelem +
85                             dim*BASIS_NCOMP*nelem*Q1D];
86   for (int i = Q1D; i < P1D; i++)
87     slice[i + tidz*T1D] = 0.0;
88 }
89 
90 //------------------------------------------------------------------------------
91 // Write quadrature point data
92 //------------------------------------------------------------------------------
93 inline __device__ void writeQuads1d(const int elem, const int tidx,
94                                     const int tidy, const int comp,
95                                     const int dim, const int nelem,
96                                     const CeedScalar &r_V, CeedScalar *d_V) {
97   if (tidx<Q1D)
98     d_V[tidx + elem*Q1D + comp*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D] = r_V;
99 }
100 
101 //------------------------------------------------------------------------------
102 // 1D tensor contraction
103 //------------------------------------------------------------------------------
104 inline __device__ void ContractX1d(CeedScalar *slice, const int tidx,
105                                    const int tidy, const int tidz,
106                                    const CeedScalar &U, const CeedScalar *B,
107                                    CeedScalar &V) {
108   V = 0.0;
109   for (int i = 0; i < P1D; ++i)
110     V += B[i + tidx*P1D] * slice[i + tidz*T1D]; // Contract x direction
111 }
112 
113 //------------------------------------------------------------------------------
114 // 1D transpose tensor contraction
115 //------------------------------------------------------------------------------
116 inline __device__ void ContractTransposeX1d(CeedScalar *slice, const int tidx,
117     const int tidy, const int tidz,
118     const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
119   V = 0.0;
120   for (int i = 0; i < Q1D; ++i)
121     V += B[tidx + i*P1D] * slice[i + tidz*T1D]; // Contract x direction
122 }
123 
124 //------------------------------------------------------------------------------
125 // 1D interpolate to quadrature points
126 //------------------------------------------------------------------------------
127 inline __device__ void interp1d(const CeedInt nelem, const int transpose,
128                                 const CeedScalar *s_B,
129                                 const CeedScalar *__restrict__ d_U,
130                                 CeedScalar *__restrict__ d_V,
131                                 CeedScalar *slice) {
132   CeedScalar r_V;
133   CeedScalar r_t;
134 
135   const int tidx = threadIdx.x;
136   const int tidy = threadIdx.y;
137   const int tidz = threadIdx.z;
138 
139 
140   for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
141        elem += gridDim.x*blockDim.z) {
142     for (int comp = 0; comp < BASIS_NCOMP; comp++) {
143       if (!transpose) {
144         readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
145         ContractX1d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
146         writeQuads1d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
147       } else {
148         readQuads1d(elem, tidx, tidy, tidz, comp, 0, nelem, d_U, slice);
149         ContractTransposeX1d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
150         writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
151       }
152     }
153   }
154 }
155 
156 //------------------------------------------------------------------------------
157 // 1D derivatives at quadrature points
158 //------------------------------------------------------------------------------
159 inline __device__ void grad1d(const CeedInt nelem, const int transpose,
160                               const CeedScalar *s_B, const CeedScalar *s_G,
161                               const CeedScalar *__restrict__ d_U,
162                               CeedScalar *__restrict__ d_V,
163                               CeedScalar *slice) {
164   CeedScalar r_U;
165   CeedScalar r_V;
166 
167   const int tidx = threadIdx.x;
168   const int tidy = threadIdx.y;
169   const int tidz = threadIdx.z;
170   int dim;
171 
172   for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
173        elem += gridDim.x*blockDim.z) {
174     for(int comp = 0; comp < BASIS_NCOMP; comp++) {
175       if (!transpose) {
176         readDofs1d(elem, tidx, tidy, tidz, comp, nelem, d_U, slice);
177         ContractX1d(slice, tidx, tidy, tidz, r_U, s_G, r_V);
178         dim = 0;
179         writeQuads1d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
180       } else {
181         dim = 0;
182         readQuads1d(elem, tidx, tidy, tidz, comp, dim, nelem, d_U, slice);
183         ContractTransposeX1d(slice, tidx, tidy, tidz, r_U, s_G, r_V);
184         writeDofs1d(elem, tidx, tidy, comp, nelem, r_V, d_V);
185       }
186     }
187   }
188 }
189 
190 //------------------------------------------------------------------------------
191 // 1D Quadrature weights
192 //------------------------------------------------------------------------------
193 __device__ void weight1d(const CeedInt nelem, const CeedScalar *qweight1d,
194                          CeedScalar *w) {
195   const int tid = threadIdx.x;
196   const CeedScalar weight = qweight1d[tid];
197   for (CeedInt elem = blockIdx.x*blockDim.y + threadIdx.y; elem < nelem;
198        elem += gridDim.x*blockDim.y) {
199     const int ind = elem*Q1D + tid;
200     w[ind] = weight;
201   }
202 }
203 
204 //------------------------------------------------------------------------------
205 // 2D
206 //------------------------------------------------------------------------------
207 
208 //------------------------------------------------------------------------------
209 // Read DoFs
210 //------------------------------------------------------------------------------
211 inline __device__ void readDofs2d(const int elem, const int tidx,
212                                   const int tidy, const int comp,
213                                   const int nelem, const CeedScalar *d_U,
214                                   CeedScalar &U) {
215   U = (tidx<P1D && tidy<P1D) ?
216       d_U[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] : 0.0;
217 }
218 
219 //------------------------------------------------------------------------------
220 // Write DoFs
221 //------------------------------------------------------------------------------
222 inline __device__ void writeDofs2d(const int elem, const int tidx,
223                                    const int tidy, const int comp,
224                                    const int nelem, const CeedScalar &r_V,
225                                    CeedScalar *d_V) {
226   if (tidx<P1D && tidy<P1D)
227     d_V[tidx + tidy*P1D + elem*P1D*P1D + comp*P1D*P1D*nelem] = r_V;
228 }
229 
230 //------------------------------------------------------------------------------
231 // Read quadrature point data
232 //------------------------------------------------------------------------------
233 inline __device__ void readQuads2d(const int elem, const int tidx,
234                                    const int tidy, const int comp,
235                                    const int dim, const int nelem,
236                                    const CeedScalar *d_U, CeedScalar &U ) {
237   U = (tidx<Q1D && tidy<Q1D) ?
238       d_U[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
239       dim*BASIS_NCOMP*nelem*Q1D*Q1D] : 0.0;
240 }
241 
242 //------------------------------------------------------------------------------
243 // Write quadrature point data
244 //------------------------------------------------------------------------------
245 inline __device__ void writeQuads2d(const int elem, const int tidx,
246                                     const int tidy, const int comp,
247                                     const int dim, const int nelem,
248                                     const CeedScalar &r_V, CeedScalar *d_V) {
249   if (tidx<Q1D && tidy<Q1D)
250     d_V[tidx + tidy*Q1D + elem*Q1D*Q1D + comp*Q1D*Q1D*nelem +
251     dim*BASIS_NCOMP*nelem*Q1D*Q1D] = r_V;
252 }
253 
254 //------------------------------------------------------------------------------
255 // 2D tensor contraction x
256 //------------------------------------------------------------------------------
257 inline __device__ void ContractX2d(CeedScalar *slice, const int tidx,
258                                    const int tidy, const int tidz,
259                                    const CeedScalar &U, const CeedScalar *B,
260                                    CeedScalar &V) {
261   slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
262   __syncthreads();
263   V = 0.0;
264   if (tidx < Q1D)
265     for (int i = 0; i < P1D; ++i)
266       V += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
267   __syncthreads();
268 }
269 
270 //------------------------------------------------------------------------------
271 // 2D tensor contraction y
272 //------------------------------------------------------------------------------
273 inline __device__ void ContractY2d(CeedScalar *slice, const int tidx,
274                                    const int tidy, const int tidz,
275                                    const CeedScalar &U, const CeedScalar *B,
276                                    CeedScalar &V) {
277   slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
278   __syncthreads();
279   V = 0.0;
280   if (tidy < Q1D)
281     for (int i = 0; i < P1D; ++i)
282       V += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
283   __syncthreads();
284 }
285 
286 //------------------------------------------------------------------------------
287 // 2D transpose tensor contraction y
288 //------------------------------------------------------------------------------
289 inline __device__ void ContractTransposeY2d(CeedScalar *slice, const int tidx,
290     const int tidy, const int tidz,
291     const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
292   slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
293   __syncthreads();
294   V = 0.0;
295   if (tidy < P1D)
296     for (int i = 0; i < Q1D; ++i)
297       V += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
298   __syncthreads();
299 }
300 
301 //------------------------------------------------------------------------------
302 // 2D transpose tensor contraction x
303 //------------------------------------------------------------------------------
304 inline __device__ void ContractTransposeX2d(CeedScalar *slice, const int tidx,
305     const int tidy, const int tidz,
306     const CeedScalar &U, const CeedScalar *B, CeedScalar &V) {
307   slice[tidx + tidy*T1D + tidz*T1D*T1D] = U;
308   __syncthreads();
309   V = 0.0;
310   if (tidx < P1D)
311     for (int i = 0; i < Q1D; ++i)
312       V += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
313   __syncthreads();
314 }
315 
316 //------------------------------------------------------------------------------
317 // 2D interpolate to quadrature points
318 //------------------------------------------------------------------------------
319 inline __device__ void interp2d(const CeedInt nelem, const int transpose,
320                                 const CeedScalar *s_B,
321                                 const CeedScalar *__restrict__ d_U,
322                                 CeedScalar *__restrict__ d_V,
323                                 CeedScalar *slice) {
324   CeedScalar r_V;
325   CeedScalar r_t;
326 
327   const int tidx = threadIdx.x;
328   const int tidy = threadIdx.y;
329   const int tidz = threadIdx.z;
330   const int blockElem = tidz/BASIS_NCOMP;
331   const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
332   const int comp = tidz%BASIS_NCOMP;
333 
334   for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
335        elem += gridDim.x*elemsPerBlock) {
336     const int comp = tidz%BASIS_NCOMP;
337     r_V = 0.0;
338     r_t = 0.0;
339     if (!transpose) {
340       readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_V);
341       ContractX2d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
342       ContractY2d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
343       writeQuads2d(elem, tidx, tidy, comp, 0, nelem, r_V, d_V);
344     } else {
345       readQuads2d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
346       ContractTransposeY2d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
347       ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
348       writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
349     }
350   }
351 }
352 
353 //------------------------------------------------------------------------------
354 // 2D derivatives at quadrature points
355 //------------------------------------------------------------------------------
356 inline __device__ void grad2d(const CeedInt nelem, const int transpose,
357                               const CeedScalar *s_B, const CeedScalar *s_G,
358                               const CeedScalar *__restrict__ d_U,
359                               CeedScalar *__restrict__ d_V, CeedScalar *slice) {
360   CeedScalar r_U;
361   CeedScalar r_V;
362   CeedScalar r_t;
363 
364   const int tidx = threadIdx.x;
365   const int tidy = threadIdx.y;
366   const int tidz = threadIdx.z;
367   const int blockElem = tidz/BASIS_NCOMP;
368   const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
369   const int comp = tidz%BASIS_NCOMP;
370   int dim;
371 
372   for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
373        elem += gridDim.x*elemsPerBlock) {
374     if (!transpose) {
375       readDofs2d(elem, tidx, tidy, comp, nelem, d_U, r_U);
376       ContractX2d(slice, tidx, tidy, tidz, r_U, s_G, r_t);
377       ContractY2d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
378       dim = 0;
379       writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
380       ContractX2d(slice, tidx, tidy, tidz, r_U, s_B, r_t);
381       ContractY2d(slice, tidx, tidy, tidz, r_t, s_G, r_V);
382       dim = 1;
383       writeQuads2d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
384     } else {
385       dim = 0;
386       readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
387       ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, s_B, r_t);
388       ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, s_G, r_V);
389       dim = 1;
390       readQuads2d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
391       ContractTransposeY2d(slice, tidx, tidy, tidz, r_U, s_G, r_t);
392       ContractTransposeX2d(slice, tidx, tidy, tidz, r_t, s_B, r_U);
393       r_V += r_U;
394       writeDofs2d(elem, tidx, tidy, comp, nelem, r_V, d_V);
395     }
396   }
397 }
398 
399 //------------------------------------------------------------------------------
400 // 2D quadrature weights
401 //------------------------------------------------------------------------------
402 __device__ void weight2d(const CeedInt nelem, const CeedScalar *qweight1d,
403                          CeedScalar *w) {
404   const int i = threadIdx.x;
405   const int j = threadIdx.y;
406   const CeedScalar weight = qweight1d[i]*qweight1d[j];
407   for (CeedInt elem = blockIdx.x*blockDim.z + threadIdx.z; elem < nelem;
408        elem += gridDim.x*blockDim.z) {
409     const int ind = elem*Q1D*Q1D + i + j*Q1D;
410     w[ind] = weight;
411   }
412 }
413 
414 //------------------------------------------------------------------------------
415 // 3D
416 //------------------------------------------------------------------------------
417 
418 //------------------------------------------------------------------------------
419 // Read DoFs
420 //------------------------------------------------------------------------------
421 inline __device__ void readDofs3d(const int elem, const int tidx,
422                                   const int tidy, const int comp,
423                                   const int nelem, const CeedScalar *d_U,
424                                   CeedScalar *r_U) {
425   for (int i = 0; i < P1D; i++)
426     r_U[i] = (tidx < P1D && tidy < P1D) ?
427               d_U[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D +
428                   comp*P1D*P1D*P1D*nelem] : 0.0;
429   for (int i = P1D; i < Q1D; i++)
430     r_U[i] = 0.0;
431 }
432 
433 //------------------------------------------------------------------------------
434 // Write DoFs
435 //------------------------------------------------------------------------------
436 inline __device__ void writeDofs3d(const int elem, const int tidx,
437                                    const int tidy, const int comp,
438                                    const int nelem, const CeedScalar *r_V,
439                                    CeedScalar *d_V) {
440   if (tidx < P1D && tidy < P1D) {
441     for (int i = 0; i < P1D; i++)
442       d_V[tidx + tidy*P1D + i*P1D*P1D + elem*P1D*P1D*P1D +
443           comp*P1D*P1D*P1D*nelem] = r_V[i];
444   }
445 }
446 
447 //------------------------------------------------------------------------------
448 // Read quadrature point data
449 //------------------------------------------------------------------------------
450 inline __device__ void readQuads3d(const int elem, const int tidx,
451                                    const int tidy, const int comp,
452                                    const int dim, const int nelem,
453                                    const CeedScalar *d_U, CeedScalar *r_U) {
454   for (int i = 0; i < Q1D; i++)
455     r_U[i] = (tidx < Q1D && tidy < Q1D) ?
456               d_U[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D +
457               comp*Q1D*Q1D*Q1D*nelem + dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] : 0.0;
458   for (int i = Q1D; i < P1D; i++)
459     r_U[i] = 0.0;
460 }
461 
462 //------------------------------------------------------------------------------
463 // Write quadrature point data
464 //------------------------------------------------------------------------------
465 inline __device__ void writeQuads3d(const int elem, const int tidx,
466                                     const int tidy, const int comp,
467                                     const int dim, const int nelem,
468                                     const CeedScalar *r_V, CeedScalar *d_V) {
469   if (tidx < Q1D && tidy < Q1D) {
470     for (int i = 0; i < Q1D; i++)
471       d_V[tidx + tidy*Q1D + i*Q1D*Q1D + elem*Q1D*Q1D*Q1D + comp*Q1D*Q1D*Q1D*nelem +
472           dim*BASIS_NCOMP*nelem*Q1D*Q1D*Q1D] = r_V[i];
473   }
474 }
475 
476 //------------------------------------------------------------------------------
477 // 3D tensor contract x
478 //------------------------------------------------------------------------------
479 inline __device__ void ContractX3d(CeedScalar *slice, const int tidx,
480                                    const int tidy, const int tidz,
481                                    const CeedScalar *U,
482                                    const CeedScalar *B,
483                                    CeedScalar *V) {
484   for (int k = 0; k < P1D; ++k) {
485     slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
486     __syncthreads();
487     V[k] = 0.0;
488     if (tidx < Q1D && tidy < P1D)
489       for (int i = 0; i < P1D; ++i)
490         V[k] += B[i + tidx*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
491     __syncthreads();
492   }
493 }
494 
495 //------------------------------------------------------------------------------
496 // 3D tensor contract y
497 //------------------------------------------------------------------------------
498 inline __device__ void ContractY3d(CeedScalar *slice, const int tidx,
499                                    const int tidy, const int tidz,
500                                    const CeedScalar *U,
501                                    const CeedScalar *B,
502                                    CeedScalar *V) {
503   for (int k = 0; k < P1D; ++k) {
504     slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
505     __syncthreads();
506     V[k] = 0.0;
507     if (tidx < Q1D && tidy < Q1D)
508       for (int i = 0; i < P1D; ++i)
509         V[k] += B[i + tidy*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
510     __syncthreads();
511   }
512 }
513 
514 //------------------------------------------------------------------------------
515 // 3D tensor contract z
516 //------------------------------------------------------------------------------
517 inline __device__ void ContractZ3d(CeedScalar *slice, const int tidx,
518                                    const int tidy, const int tidz,
519                                    const CeedScalar *U,
520                                    const CeedScalar *B,
521                                    CeedScalar *V) {
522   for (int k = 0; k < Q1D; ++k) {
523     V[k] = 0.0;
524     if (tidx < Q1D && tidy < Q1D)
525       for (int i = 0; i < P1D; ++i)
526         V[k] += B[i + k*P1D] * U[i]; // Contract z direction
527   }
528   for (int k = Q1D; k < P1D; ++k)
529     V[k] = 0.0;
530 }
531 
532 //------------------------------------------------------------------------------
533 // 3D transpose tensor contract z
534 //------------------------------------------------------------------------------
535 inline __device__ void ContractTransposeZ3d(CeedScalar *slice, const int tidx,
536                                             const int tidy, const int tidz,
537                                             const CeedScalar *U,
538                                             const CeedScalar *B,
539                                             CeedScalar *V) {
540   for (int k = 0; k < P1D; ++k) {
541     V[k] = 0.0;
542     if (tidx < Q1D && tidy < Q1D)
543       for (int i = 0; i < Q1D; ++i)
544         V[k] += B[k + i*P1D] * U[i]; // Contract z direction
545   }
546   for (int k = P1D; k < Q1D; ++k)
547     V[k] = 0.0;
548 }
549 
550 //------------------------------------------------------------------------------
551 // 3D transpose tensor contract y
552 //------------------------------------------------------------------------------
553 inline __device__ void ContractTransposeY3d(CeedScalar *slice, const int tidx,
554                                             const int tidy, const int tidz,
555                                             const CeedScalar *U,
556                                             const CeedScalar *B,
557                                             CeedScalar *V) {
558   for (int k = 0; k < P1D; ++k) {
559     slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
560     __syncthreads();
561     V[k] = 0.0;
562     if (tidx < Q1D && tidy < P1D)
563       for (int i = 0; i < Q1D; ++i)
564         V[k] += B[tidy + i*P1D] * slice[tidx + i*T1D + tidz*T1D*T1D]; // Contract y direction
565     __syncthreads();
566   }
567 }
568 
569 //------------------------------------------------------------------------------
570 // 3D transpose tensor contract x
571 //------------------------------------------------------------------------------
572 inline __device__ void ContractTransposeX3d(CeedScalar *slice, const int tidx,
573                                             const int tidy, const int tidz,
574                                             const CeedScalar *U,
575                                             const CeedScalar *B,
576                                             CeedScalar *V) {
577   for (int k = 0; k < P1D; ++k) {
578     slice[tidx + tidy*T1D + tidz*T1D*T1D] = U[k];
579     __syncthreads();
580     V[k] = 0.0;
581     if (tidx < P1D && tidy < P1D)
582       for (int i = 0; i < Q1D; ++i)
583         V[k] += B[tidx + i*P1D] * slice[i + tidy*T1D + tidz*T1D*T1D]; // Contract x direction
584     __syncthreads();
585   }
586 }
587 
588 //------------------------------------------------------------------------------
589 // 3D interpolate to quadrature points
590 //------------------------------------------------------------------------------
591 inline __device__ void interp3d(const CeedInt nelem, const int transpose,
592                                 const CeedScalar *s_B,
593                                 const CeedScalar *__restrict__ d_U,
594                                 CeedScalar *__restrict__ d_V,
595                                 CeedScalar *slice) {
596   CeedScalar r_V[T1D];
597   CeedScalar r_t[T1D];
598 
599   const int tidx = threadIdx.x;
600   const int tidy = threadIdx.y;
601   const int tidz = threadIdx.z;
602   const int blockElem = tidz/BASIS_NCOMP;
603   const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
604   const int comp = tidz%BASIS_NCOMP;
605 
606   for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
607        elem += gridDim.x*elemsPerBlock) {
608     for (int i = 0; i < T1D; ++i) {
609       r_V[i] = 0.0;
610       r_t[i] = 0.0;
611     }
612     if (!transpose) {
613       readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_V);
614       ContractX3d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
615       ContractY3d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
616       ContractZ3d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
617       writeQuads3d(elem, tidx, tidy, comp, 0, nelem, r_t, d_V);
618     } else {
619       readQuads3d(elem, tidx, tidy, comp, 0, nelem, d_U, r_V);
620       ContractTransposeZ3d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
621       ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
622       ContractTransposeX3d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
623       writeDofs3d(elem, tidx, tidy, comp, nelem, r_t, d_V);
624     }
625   }
626 }
627 
628 //------------------------------------------------------------------------------
629 // 3D derivatives at quadrature points
630 //------------------------------------------------------------------------------
631 inline __device__ void grad3d(const CeedInt nelem, const int transpose,
632                               const CeedScalar *s_B, const CeedScalar *s_G,
633                               const CeedScalar *__restrict__ d_U,
634                               CeedScalar *__restrict__ d_V,
635                               CeedScalar *slice) {
636   // Use P1D for one of these
637   CeedScalar r_U[T1D];
638   CeedScalar r_V[T1D];
639   CeedScalar r_t[T1D];
640 
641   const int tidx = threadIdx.x;
642   const int tidy = threadIdx.y;
643   const int tidz = threadIdx.z;
644   const int blockElem = tidz/BASIS_NCOMP;
645   const int elemsPerBlock = blockDim.z/BASIS_NCOMP;
646   const int comp = tidz%BASIS_NCOMP;
647   int dim;
648 
649   for (CeedInt elem = blockIdx.x*elemsPerBlock + blockElem; elem < nelem;
650        elem += gridDim.x*elemsPerBlock) {
651     for (int i = 0; i < T1D; ++i) {
652       r_U[i] = 0.0;
653       r_V[i] = 0.0;
654       r_t[i] = 0.0;
655     }
656     if (!transpose) {
657       readDofs3d(elem, tidx, tidy, comp, nelem, d_U, r_U);
658       ContractX3d(slice, tidx, tidy, tidz, r_U, s_G, r_V);
659       ContractY3d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
660       ContractZ3d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
661       dim = 0;
662       writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
663       ContractX3d(slice, tidx, tidy, tidz, r_U, s_B, r_V);
664       ContractY3d(slice, tidx, tidy, tidz, r_V, s_G, r_t);
665       ContractZ3d(slice, tidx, tidy, tidz, r_t, s_B, r_V);
666       dim = 1;
667       writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
668       ContractX3d(slice, tidx, tidy, tidz, r_U, s_B, r_V);
669       ContractY3d(slice, tidx, tidy, tidz, r_V, s_B, r_t);
670       ContractZ3d(slice, tidx, tidy, tidz, r_t, s_G, r_V);
671       dim = 2;
672       writeQuads3d(elem, tidx, tidy, comp, dim, nelem, r_V, d_V);
673     } else {
674       dim = 0;
675       readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
676       ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, s_B, r_t);
677       ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, s_B, r_U);
678       ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, s_G, r_V);
679       dim = 1;
680       readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
681       ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, s_B, r_t);
682       ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, s_G, r_U);
683       ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, s_B, r_t);
684       add(r_V, r_t);
685       dim = 2;
686       readQuads3d(elem, tidx, tidy, comp, dim, nelem, d_U, r_U);
687       ContractTransposeZ3d(slice, tidx, tidy, tidz, r_U, s_G, r_t);
688       ContractTransposeY3d(slice, tidx, tidy, tidz, r_t, s_B, r_U);
689       ContractTransposeX3d(slice, tidx, tidy, tidz, r_U, s_B, r_t);
690       add(r_V, r_t);
691       writeDofs3d(elem, tidx, tidy, comp, nelem, r_V, d_V);
692     }
693   }
694 }
695 
696 //------------------------------------------------------------------------------
697 // 3D quadrature weights
698 //------------------------------------------------------------------------------
__device__ void weight3d(const CeedInt nelem, const CeedScalar *qweight1d,
                         CeedScalar *w) {
  // Each thread owns one quadrature point (i, j, k); the host launches this
  // with a Q1D x Q1D x Q1D thread block (see the dim == 3 weight launch).
  const int i = threadIdx.x;
  const int j = threadIdx.y;
  const int k = threadIdx.z;
  // 3D weight is the tensor product of the three 1D weights
  const CeedScalar weight = qweight1d[i]*qweight1d[j]*qweight1d[k];
  // Grid-stride loop over elements; every element gets the same weight
  // value at this quadrature point
  for (int e = blockIdx.x; e < nelem; e += gridDim.x) {
    const int ind = e*Q1D*Q1D*Q1D + i + j*Q1D + k*Q1D*Q1D;
    w[ind] = weight;
  }
}
710 
711 
712 //------------------------------------------------------------------------------
713 // Basis kernels
714 //------------------------------------------------------------------------------
715 
716 //------------------------------------------------------------------------------
717 // Interp kernel by dim
718 //------------------------------------------------------------------------------
// Entry point: interpolation between nodes and quadrature points.
// Dynamic shared memory ("slice") is sized by the host launch and used as
// scratch by the per-dimension workers.
extern "C" __launch_bounds__(INTERP_BLKSIZE) __global__ void interp(
                                  const CeedInt nelem, const int transpose,
                                  CeedScalar *d_interp1d,
                                  const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {

  HIP_DYNAMIC_SHARED( double, slice)
  // load interp1d into shared memory
  __shared__ double s_B[P1D*Q1D];
  loadMatrix(d_interp1d, s_B);
  __syncthreads();

  // BASIS_DIM is a compile-time define (see CeedCompileHip below), so only
  // the matching branch survives compilation
  if (BASIS_DIM == 1) {
    interp1d(nelem, transpose, s_B, d_U, d_V, slice);
  } else if (BASIS_DIM == 2) {
    interp2d(nelem, transpose, s_B, d_U, d_V, slice);
  } else if (BASIS_DIM == 3) {
    interp3d(nelem, transpose, s_B, d_U, d_V, slice);
  }
}
739 
740 //------------------------------------------------------------------------------
741 // Grad kernel by dim
742 //------------------------------------------------------------------------------
// Entry point: gradient between nodes and quadrature points.
// Loads both the interpolation (s_B) and derivative (s_G) 1D matrices into
// static shared memory, then dispatches on the basis dimension.
extern "C" __launch_bounds__(GRAD_BLKSIZE) __global__ void grad(const CeedInt nelem,
                                const int transpose,
                                CeedScalar *d_interp1d, CeedScalar *d_grad1d,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V) {
  HIP_DYNAMIC_SHARED( double, slice)
  // load interp1d and grad1d into shared memory
  __shared__ double s_B[P1D*Q1D];
  loadMatrix(d_interp1d, s_B);
  __shared__ double s_G[P1D*Q1D];
  loadMatrix(d_grad1d, s_G);
  // Barrier: all threads must see the fully-loaded matrices before use
  __syncthreads();

  // BASIS_DIM is a compile-time define, so only one branch survives
  if (BASIS_DIM == 1) {
    grad1d(nelem, transpose, s_B, s_G, d_U, d_V, slice);
  } else if (BASIS_DIM == 2) {
    grad2d(nelem, transpose, s_B, s_G, d_U, d_V, slice);
  } else if (BASIS_DIM == 3) {
    grad3d(nelem, transpose, s_B, s_G, d_U, d_V, slice);
  }
}
764 
765 //------------------------------------------------------------------------------
766 // Weight kernels by dim
767 //------------------------------------------------------------------------------
// Entry point: quadrature weights; no input vector, writes weights into v.
extern "C" __launch_bounds__(WEIGHT_BLKSIZE) __global__ void weight(const CeedInt nelem,
                                  const CeedScalar *__restrict__ qweight1d,
                                  CeedScalar *__restrict__ v) {
  // BASIS_DIM is a compile-time define, so only one branch survives
  if (BASIS_DIM == 1) {
    weight1d(nelem, qweight1d, v);
  } else if (BASIS_DIM == 2) {
    weight2d(nelem, qweight1d, v);
  } else if (BASIS_DIM == 3) {
    weight3d(nelem, qweight1d, v);
  }
}
779 
780 );
781 // *INDENT-ON*
782 
783 //------------------------------------------------------------------------------
784 // Compute a block size based on required minimum threads
785 //------------------------------------------------------------------------------
786 static CeedInt ComputeBlockSizeFromRequirement(const CeedInt required) {
787   CeedInt maxSize = 1024;    // Max total threads per block
788   CeedInt currentSize = 64;  // Start with one group
789 
790   while(currentSize < maxSize) {
791     if (currentSize > required)
792       break;
793     else
794       currentSize = currentSize * 2;
795   }
796   return currentSize;
797 }
798 
799 //------------------------------------------------------------------------------
800 // Compute required thread block sizes for basis kernels given P, Q, dim, and
801 // ncomp
802 //------------------------------------------------------------------------------
803 static int ComputeBasisThreadBlockSizes(const CeedInt dim, const CeedInt P1d,
804                                         const CeedInt Q1d,
805                                         const CeedInt ncomp, CeedInt *blksizes) {
806 
807   // Note that this will use the same block sizes for all dimensions when compiling,
808   // but as each basis object is defined for a particular dimension, we will never
809   // call any kernels except the ones for the dimension for which we have computed the
810   // block sizes.
811   const CeedInt thread1d = CeedIntMax(P1d, Q1d);
812   switch (dim) {
813   case 1: {
814     // Interp kernels:
815     blksizes[0] = 256;
816 
817     // Grad kernels:
818     blksizes[1] = 256;
819 
820     // Weight kernels:
821     blksizes[2] = 256;
822 
823   } break;
824   case 2: {
825     // Interp kernels:
826     CeedInt required = thread1d * thread1d * ncomp;
827     blksizes[0]  = ComputeBlockSizeFromRequirement(required);
828 
829     // Grad kernels: currently use same required minimum threads
830     blksizes[1]  = ComputeBlockSizeFromRequirement(required);
831 
832     // Weight kernels:
833     required = CeedIntMax(64, Q1d * Q1d);
834     blksizes[2]  = ComputeBlockSizeFromRequirement(required);
835 
836   } break;
837   case 3: {
838     // Interp kernels:
839     CeedInt required = thread1d * thread1d * ncomp;
840     blksizes[0]  = ComputeBlockSizeFromRequirement(required);
841 
842     // Grad kernels: currently use same required minimum threads
843     blksizes[1]  = ComputeBlockSizeFromRequirement(required);
844 
845     // Weight kernels:
846     required = Q1d * Q1d * Q1d;
847     blksizes[2]  = ComputeBlockSizeFromRequirement(required);
848   }
849   }
850 
851   return CEED_ERROR_SUCCESS;
852 }
853 
854 //------------------------------------------------------------------------------
855 // Apply basis
856 //------------------------------------------------------------------------------
857 int CeedBasisApplyTensor_Hip_shared(CeedBasis basis, const CeedInt nelem,
858                                     CeedTransposeMode tmode,
859                                     CeedEvalMode emode, CeedVector u,
860                                     CeedVector v) {
861   int ierr;
862   Ceed ceed;
863   ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
864   Ceed_Hip_shared *ceed_Hip;
865   CeedGetData(ceed, &ceed_Hip); CeedChkBackend(ierr);
866   CeedBasis_Hip_shared *data;
867   CeedBasisGetData(basis, &data); CeedChkBackend(ierr);
868   const CeedInt transpose = tmode == CEED_TRANSPOSE;
869   CeedInt dim, ncomp;
870   ierr = CeedBasisGetDimension(basis, &dim); CeedChkBackend(ierr);
871   ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChkBackend(ierr);
872 
873   // Read vectors
874   const CeedScalar *d_u;
875   CeedScalar *d_v;
876   if (emode != CEED_EVAL_WEIGHT) {
877     ierr = CeedVectorGetArrayRead(u, CEED_MEM_DEVICE, &d_u); CeedChkBackend(ierr);
878   }
879   ierr = CeedVectorGetArray(v, CEED_MEM_DEVICE, &d_v); CeedChkBackend(ierr);
880 
881   // Clear v for transpose mode
882   if (tmode == CEED_TRANSPOSE) {
883     CeedInt length;
884     ierr = CeedVectorGetLength(v, &length); CeedChkBackend(ierr);
885     ierr = hipMemset(d_v, 0, length * sizeof(CeedScalar)); CeedChkBackend(ierr);
886   }
887 
888   // Apply basis operation
889   switch (emode) {
890   case CEED_EVAL_INTERP: {
891     CeedInt P1d, Q1d;
892     CeedInt blksize = data->blksizes[0];
893     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChkBackend(ierr);
894     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr);
895     CeedInt thread1d = CeedIntMax(Q1d, P1d);
896     void *interpargs[] = {(void *) &nelem, (void *) &transpose, &data->d_interp1d,
897                           &d_u, &d_v
898                          };
899     if (dim == 1) {
900       CeedInt elemsPerBlock = 64*thread1d > 256? 256/thread1d : 64;
901       elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1;
902       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
903                                              ? 1 : 0 );
904       CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
905       ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, 1,
906                                        elemsPerBlock, sharedMem,
907                                        interpargs); CeedChkBackend(ierr);
908     } else if (dim == 2) {
909       // Check if required threads is small enough to do multiple elems
910       const CeedInt elemsPerBlock = CeedIntMax(blksize/(thread1d*thread1d*ncomp), 1);
911       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
912                                              ? 1 : 0 );
913       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
914       ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d,
915                                        ncomp*elemsPerBlock, sharedMem,
916                                        interpargs); CeedChkBackend(ierr);
917     } else if (dim == 3) {
918       CeedInt elemsPerBlock = 1;
919       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
920                                              ? 1 : 0 );
921       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
922       ierr = CeedRunKernelDimSharedHip(ceed, data->interp, grid, thread1d, thread1d,
923                                        ncomp*elemsPerBlock, sharedMem,
924                                        interpargs); CeedChkBackend(ierr);
925     }
926   } break;
927   case CEED_EVAL_GRAD: {
928     CeedInt P1d, Q1d;
929     CeedInt blksize = data->blksizes[1];
930     ierr = CeedBasisGetNumNodes1D(basis, &P1d); CeedChkBackend(ierr);
931     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr);
932     CeedInt thread1d = CeedIntMax(Q1d, P1d);
933     void *gradargs[] = {(void *) &nelem, (void *) &transpose, &data->d_interp1d,
934                         &data->d_grad1d, &d_u, &d_v
935                        };
936     if (dim == 1) {
937       CeedInt elemsPerBlock = 64*thread1d > 256? 256/thread1d : 64;
938       elemsPerBlock = elemsPerBlock>0?elemsPerBlock:1;
939       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
940                                              ? 1 : 0 );
941       CeedInt sharedMem = elemsPerBlock*thread1d*sizeof(CeedScalar);
942       ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, 1,
943                                        elemsPerBlock, sharedMem, gradargs);
944       CeedChkBackend(ierr);
945     } else if (dim == 2) {
946       // Check if required threads is small enough to do multiple elems
947       const CeedInt elemsPerBlock = CeedIntMax(blksize/(thread1d*thread1d*ncomp), 1);
948       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
949                                              ? 1 : 0 );
950       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
951       ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d,
952                                        ncomp*elemsPerBlock, sharedMem,
953                                        gradargs); CeedChkBackend(ierr);
954     } else if (dim == 3) {
955       CeedInt elemsPerBlock = 1;
956       CeedInt grid = nelem/elemsPerBlock + ( (nelem/elemsPerBlock*elemsPerBlock<nelem)
957                                              ? 1 : 0 );
958       CeedInt sharedMem = ncomp*elemsPerBlock*thread1d*thread1d*sizeof(CeedScalar);
959       ierr = CeedRunKernelDimSharedHip(ceed, data->grad, grid, thread1d, thread1d,
960                                        ncomp*elemsPerBlock, sharedMem,
961                                        gradargs); CeedChkBackend(ierr);
962     }
963   } break;
964   case CEED_EVAL_WEIGHT: {
965     CeedInt Q1d;
966     CeedInt blksize = data->blksizes[2];
967     ierr = CeedBasisGetNumQuadraturePoints1D(basis, &Q1d); CeedChkBackend(ierr);
968     void *weightargs[] = {(void *) &nelem, (void *) &data->d_qweight1d, &d_v};
969     if (dim == 1) {
970       const CeedInt optElems = blksize/Q1d;
971       const CeedInt elemsPerBlock = optElems>0?optElems:1;
972       const CeedInt gridsize = nelem/elemsPerBlock + ( (
973                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
974       ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d,
975                                  elemsPerBlock, 1, weightargs);
976       CeedChkBackend(ierr);
977     } else if (dim == 2) {
978       const CeedInt optElems = blksize/(Q1d*Q1d);
979       const CeedInt elemsPerBlock = optElems>0?optElems:1;
980       const CeedInt gridsize = nelem/elemsPerBlock + ( (
981                                  nelem/elemsPerBlock*elemsPerBlock<nelem)? 1 : 0 );
982       ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d,
983                                  elemsPerBlock, weightargs);
984       CeedChkBackend(ierr);
985     } else if (dim == 3) {
986       const CeedInt gridsize = nelem;
987       ierr = CeedRunKernelDimHip(ceed, data->weight, gridsize, Q1d, Q1d, Q1d,
988                                  weightargs);
989       CeedChkBackend(ierr);
990     }
991   } break;
992   // LCOV_EXCL_START
993   // Evaluate the divergence to/from the quadrature points
994   case CEED_EVAL_DIV:
995     return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_DIV not supported");
996   // Evaluate the curl to/from the quadrature points
997   case CEED_EVAL_CURL:
998     return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_CURL not supported");
999   // Take no action, BasisApply should not have been called
1000   case CEED_EVAL_NONE:
1001     return CeedError(ceed, CEED_ERROR_BACKEND,
1002                      "CEED_EVAL_NONE does not make sense in this context");
1003     // LCOV_EXCL_STOP
1004   }
1005 
1006   // Restore vectors
1007   if (emode != CEED_EVAL_WEIGHT) {
1008     ierr = CeedVectorRestoreArrayRead(u, &d_u); CeedChkBackend(ierr);
1009   }
1010   ierr = CeedVectorRestoreArray(v, &d_v); CeedChkBackend(ierr);
1011   return CEED_ERROR_SUCCESS;
1012 }
1013 
1014 //------------------------------------------------------------------------------
1015 // Destroy basis
1016 //------------------------------------------------------------------------------
1017 static int CeedBasisDestroy_Hip_shared(CeedBasis basis) {
1018   int ierr;
1019   Ceed ceed;
1020   ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
1021 
1022   CeedBasis_Hip_shared *data;
1023   ierr = CeedBasisGetData(basis, &data); CeedChkBackend(ierr);
1024 
1025   CeedChk_Hip(ceed, hipModuleUnload(data->module));
1026 
1027   ierr = hipFree(data->d_qweight1d); CeedChk_Hip(ceed, ierr);
1028   ierr = hipFree(data->d_interp1d); CeedChk_Hip(ceed, ierr);
1029   ierr = hipFree(data->d_grad1d); CeedChk_Hip(ceed, ierr);
1030   ierr = hipFree(data->d_collograd1d); CeedChk_Hip(ceed, ierr);
1031 
1032   ierr = CeedFree(&data); CeedChkBackend(ierr);
1033 
1034   return CEED_ERROR_SUCCESS;
1035 }
1036 
1037 //------------------------------------------------------------------------------
1038 // Create tensor basis
1039 //------------------------------------------------------------------------------
// Create a tensor-product H1 basis: copy the 1D matrices to the device,
// choose thread-block sizes, JIT-compile the shared-memory kernels, and
// register the Apply/Destroy backend functions.
// Note: qref1d is accepted for interface compatibility but not copied here.
int CeedBasisCreateTensorH1_Hip_shared(CeedInt dim, CeedInt P1d, CeedInt Q1d,
                                       const CeedScalar *interp1d,
                                       const CeedScalar *grad1d,
                                       const CeedScalar *qref1d,
                                       const CeedScalar *qweight1d,
                                       CeedBasis basis) {
  int ierr;
  Ceed ceed;
  ierr = CeedBasisGetCeed(basis, &ceed); CeedChkBackend(ierr);
  CeedBasis_Hip_shared *data;
  ierr = CeedCalloc(1, &data); CeedChkBackend(ierr);

  // Copy basis data to GPU
  // qBytes = bytes of one length-Q1d array; iBytes = bytes of a Q1d x P1d matrix
  const CeedInt qBytes = Q1d * sizeof(CeedScalar);
  ierr = hipMalloc((void **)&data->d_qweight1d, qBytes); CeedChk_Hip(ceed, ierr);
  ierr = hipMemcpy(data->d_qweight1d, qweight1d, qBytes,
                   hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);

  const CeedInt iBytes = qBytes * P1d;
  ierr = hipMalloc((void **)&data->d_interp1d, iBytes); CeedChk_Hip(ceed, ierr);
  ierr = hipMemcpy(data->d_interp1d, interp1d, iBytes,
                   hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);

  ierr = hipMalloc((void **)&data->d_grad1d, iBytes); CeedChk_Hip(ceed, ierr);
  ierr = hipMemcpy(data->d_grad1d, grad1d, iBytes,
                   hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);

  // Compute collocated gradient and copy to GPU
  // Only built for 3D with Q1d >= P1d; NOTE(review): d_collograd1d is not
  // referenced by the kernels visible in this file — presumably used by
  // another kernel variant; confirm before removing
  data->d_collograd1d = NULL;
  if (dim == 3 && Q1d >= P1d) {
    CeedScalar *collograd1d;
    ierr = CeedMalloc(Q1d*Q1d, &collograd1d); CeedChkBackend(ierr);
    ierr = CeedBasisGetCollocatedGrad(basis, collograd1d); CeedChkBackend(ierr);
    ierr = hipMalloc((void **)&data->d_collograd1d, qBytes * Q1d);
    CeedChk_Hip(ceed, ierr);
    ierr = hipMemcpy(data->d_collograd1d, collograd1d, qBytes * Q1d,
                     hipMemcpyHostToDevice); CeedChk_Hip(ceed, ierr);
    ierr = CeedFree(&collograd1d); CeedChkBackend(ierr);
  }

  // Set number of threads per block for basis kernels
  CeedInt ncomp;
  ierr = CeedBasisGetNumComponents(basis, &ncomp); CeedChkBackend(ierr);
  ierr = ComputeBasisThreadBlockSizes(dim, P1d, Q1d, ncomp, data->blksizes);
  CeedChkBackend(ierr);

  // Compile basis kernels
  // The 11 name/value pairs become #defines in the JIT-compiled source
  // (kernelsShared above); T1D = max(P1d, Q1d) sizes the register pencils
  ierr = CeedCompileHip(ceed, kernelsShared, &data->module, 11,
                        "Q1D", Q1d,
                        "P1D", P1d,
                        "T1D", CeedIntMax(Q1d, P1d),
                        "BASIS_BUF_LEN", ncomp * CeedIntPow(Q1d > P1d ?
                            Q1d : P1d, dim),
                        "BASIS_DIM", dim,
                        "BASIS_NCOMP", ncomp,
                        "BASIS_ELEMSIZE", CeedIntPow(P1d, dim),
                        "BASIS_NQPT", CeedIntPow(Q1d, dim),
                        "INTERP_BLKSIZE", data->blksizes[0],
                        "GRAD_BLKSIZE", data->blksizes[1],
                        "WEIGHT_BLKSIZE", data->blksizes[2]
                       ); CeedChkBackend(ierr);
  ierr = CeedGetKernelHip(ceed, data->module, "interp", &data->interp);
  CeedChkBackend(ierr);
  ierr = CeedGetKernelHip(ceed, data->module, "grad", &data->grad);
  CeedChkBackend(ierr);
  ierr = CeedGetKernelHip(ceed, data->module, "weight", &data->weight);
  CeedChkBackend(ierr);

  ierr = CeedBasisSetData(basis, data); CeedChkBackend(ierr);

  // Register backend functions
  ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Apply",
                                CeedBasisApplyTensor_Hip_shared);
  CeedChkBackend(ierr);
  ierr = CeedSetBackendFunction(ceed, "Basis", basis, "Destroy",
                                CeedBasisDestroy_Hip_shared); CeedChkBackend(ierr);
  return CEED_ERROR_SUCCESS;
}
1118 //------------------------------------------------------------------------------
1119