xref: /libCEED/include/ceed/jit-source/hip/hip-shared-basis-tensor.h (revision 26ef7cdab87216015dfdd84a70b0ca05f3e531f1)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 /// @file
9 /// Internal header for HIP shared memory tensor product basis
10 #include <ceed/types.h>
11 
12 #include "hip-shared-basis-read-write-templates.h"
13 #include "hip-shared-basis-tensor-templates.h"
14 
15 //------------------------------------------------------------------------------
16 // Interp kernel by dim
17 //------------------------------------------------------------------------------
// Interp: apply the 1D interpolation matrix to d_U and store the result in d_V.
//   num_elem    - number of elements to process
//   d_interp_1d - BASIS_P_1D * BASIS_Q_1D matrix entries, staged in shared memory
//   d_U         - input values, fetched with ReadElementStrided{1,2,3}d
//   d_V         - output values, stored with WriteElementStrided{1,2,3}d
// Every z-layer of the thread block works on one element per loop pass.
// NOTE(review): z-layers can leave the element loop at different times; if the tensor templates
//   synchronize the whole block internally, confirm that this is safe.
extern "C" __launch_bounds__(BASIS_INTERP_BLOCK_SIZE) __global__
    void Interp(const CeedInt num_elem, const CeedScalar *d_interp_1d, const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V) {
  extern __shared__ CeedScalar slice[];

  // Stage interp_1d in shared memory; the barrier orders the load before any later use of s_B
  __shared__ CeedScalar s_B[BASIS_P_1D * BASIS_Q_1D];
  loadMatrix<BASIS_P_1D * BASIS_Q_1D>(d_interp_1d, s_B);
  __syncthreads();

  // Spacing between the windows of the dynamic shared array handed to consecutive z-layers
  const CeedInt slice_size = T_1D * (BASIS_DIM > 1 ? T_1D : 1);
  // P_1D^dim and Q_1D^dim: forwarded below as the stride arguments for d_U and d_V
  const CeedInt u_size = BASIS_P_1D * (BASIS_DIM > 1 ? BASIS_P_1D : 1) * (BASIS_DIM > 2 ? BASIS_P_1D : 1);
  const CeedInt v_size = BASIS_Q_1D * (BASIS_DIM > 1 ? BASIS_Q_1D : 1) * (BASIS_DIM > 2 ? BASIS_Q_1D : 1);

  // Thread coordinates and this z-layer's shared memory window
  SharedData_Hip data;
  data.t_id_x = threadIdx.x;
  data.t_id_y = threadIdx.y;
  data.t_id_z = threadIdx.z;
  data.t_id   = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.y * blockDim.x;
  data.slice  = slice + data.t_id_z * slice_size;

  // Per-thread input and output storage
  CeedScalar r_U[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)];
  CeedScalar r_V[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)];

  // Element loop: start at this z-layer's element and advance by the number of z-layers in the grid
  for (CeedInt elem = blockIdx.x * blockDim.z + threadIdx.z; elem < num_elem; elem += gridDim.x * blockDim.z) {
    switch (BASIS_DIM) {
      case 1:
        ReadElementStrided1d<BASIS_NUM_COMP, BASIS_P_1D>(data, elem, 1, u_size * num_elem, u_size, d_U, r_U);
        Interp1d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, s_B, r_V);
        WriteElementStrided1d<BASIS_NUM_COMP, BASIS_Q_1D>(data, elem, 1, v_size * num_elem, v_size, r_V, d_V);
        break;
      case 2:
        ReadElementStrided2d<BASIS_NUM_COMP, BASIS_P_1D>(data, elem, 1, u_size * num_elem, u_size, d_U, r_U);
        InterpTensor2d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, s_B, r_V);
        WriteElementStrided2d<BASIS_NUM_COMP, BASIS_Q_1D>(data, elem, 1, v_size * num_elem, v_size, r_V, d_V);
        break;
      case 3:
        ReadElementStrided3d<BASIS_NUM_COMP, BASIS_P_1D>(data, elem, 1, u_size * num_elem, u_size, d_U, r_U);
        InterpTensor3d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, s_B, r_V);
        WriteElementStrided3d<BASIS_NUM_COMP, BASIS_Q_1D>(data, elem, 1, v_size * num_elem, v_size, r_V, d_V);
        break;
      default:
        break;
    }
  }
}
55 
// InterpTranspose: apply the transposed 1D interpolation to d_U and store the result in d_V.
//   num_elem    - number of elements to process
//   d_interp_1d - BASIS_P_1D * BASIS_Q_1D matrix entries, staged in shared memory
//   d_U         - input values (BASIS_Q_1D entries per direction), fetched with ReadElementStrided{1,2,3}d
//   d_V         - output values (BASIS_P_1D entries per direction), stored with WriteElementStrided{1,2,3}d
// Every z-layer of the thread block works on one element per loop pass.
// NOTE(review): z-layers can leave the element loop at different times; if the tensor templates
//   synchronize the whole block internally, confirm that this is safe.
extern "C" __launch_bounds__(BASIS_INTERP_BLOCK_SIZE) __global__
    void InterpTranspose(const CeedInt num_elem, const CeedScalar *d_interp_1d, const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V) {
  extern __shared__ CeedScalar slice[];

  // Stage interp_1d in shared memory; the barrier orders the load before any later use of s_B
  __shared__ CeedScalar s_B[BASIS_P_1D * BASIS_Q_1D];
  loadMatrix<BASIS_P_1D * BASIS_Q_1D>(d_interp_1d, s_B);
  __syncthreads();

  // Spacing between the windows of the dynamic shared array handed to consecutive z-layers
  const CeedInt slice_size = T_1D * (BASIS_DIM > 1 ? T_1D : 1);
  // Q_1D^dim and P_1D^dim: forwarded below as the stride arguments for d_U and d_V
  const CeedInt u_size = BASIS_Q_1D * (BASIS_DIM > 1 ? BASIS_Q_1D : 1) * (BASIS_DIM > 2 ? BASIS_Q_1D : 1);
  const CeedInt v_size = BASIS_P_1D * (BASIS_DIM > 1 ? BASIS_P_1D : 1) * (BASIS_DIM > 2 ? BASIS_P_1D : 1);

  // Thread coordinates and this z-layer's shared memory window
  SharedData_Hip data;
  data.t_id_x = threadIdx.x;
  data.t_id_y = threadIdx.y;
  data.t_id_z = threadIdx.z;
  data.t_id   = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.y * blockDim.x;
  data.slice  = slice + data.t_id_z * slice_size;

  // Per-thread input and output storage
  CeedScalar r_U[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)];
  CeedScalar r_V[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)];

  // Element loop: start at this z-layer's element and advance by the number of z-layers in the grid
  for (CeedInt elem = blockIdx.x * blockDim.z + threadIdx.z; elem < num_elem; elem += gridDim.x * blockDim.z) {
    switch (BASIS_DIM) {
      case 1:
        ReadElementStrided1d<BASIS_NUM_COMP, BASIS_Q_1D>(data, elem, 1, u_size * num_elem, u_size, d_U, r_U);
        InterpTranspose1d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, s_B, r_V);
        WriteElementStrided1d<BASIS_NUM_COMP, BASIS_P_1D>(data, elem, 1, v_size * num_elem, v_size, r_V, d_V);
        break;
      case 2:
        ReadElementStrided2d<BASIS_NUM_COMP, BASIS_Q_1D>(data, elem, 1, u_size * num_elem, u_size, d_U, r_U);
        InterpTransposeTensor2d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, s_B, r_V);
        WriteElementStrided2d<BASIS_NUM_COMP, BASIS_P_1D>(data, elem, 1, v_size * num_elem, v_size, r_V, d_V);
        break;
      case 3:
        ReadElementStrided3d<BASIS_NUM_COMP, BASIS_Q_1D>(data, elem, 1, u_size * num_elem, u_size, d_U, r_U);
        InterpTransposeTensor3d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, s_B, r_V);
        WriteElementStrided3d<BASIS_NUM_COMP, BASIS_P_1D>(data, elem, 1, v_size * num_elem, v_size, r_V, d_V);
        break;
      default:
        break;
    }
  }
}
93 
// InterpTransposeAdd: same computation as InterpTranspose, but the result is handed to
// SumElementStrided{1,2,3}d instead of WriteElementStrided{1,2,3}d
// (presumably accumulating into d_V rather than overwriting it - confirm in the read/write templates).
//   num_elem    - number of elements to process
//   d_interp_1d - BASIS_P_1D * BASIS_Q_1D matrix entries, staged in shared memory
//   d_U         - input values (BASIS_Q_1D entries per direction), fetched with ReadElementStrided{1,2,3}d
//   d_V         - output values (BASIS_P_1D entries per direction), updated with SumElementStrided{1,2,3}d
// Every z-layer of the thread block works on one element per loop pass.
// NOTE(review): z-layers can leave the element loop at different times; if the tensor templates
//   synchronize the whole block internally, confirm that this is safe.
extern "C" __launch_bounds__(BASIS_INTERP_BLOCK_SIZE) __global__
    void InterpTransposeAdd(const CeedInt num_elem, const CeedScalar *d_interp_1d, const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V) {
  extern __shared__ CeedScalar slice[];

  // Stage interp_1d in shared memory; the barrier orders the load before any later use of s_B
  __shared__ CeedScalar s_B[BASIS_P_1D * BASIS_Q_1D];
  loadMatrix<BASIS_P_1D * BASIS_Q_1D>(d_interp_1d, s_B);
  __syncthreads();

  // Spacing between the windows of the dynamic shared array handed to consecutive z-layers
  const CeedInt slice_size = T_1D * (BASIS_DIM > 1 ? T_1D : 1);
  // Q_1D^dim and P_1D^dim: forwarded below as the stride arguments for d_U and d_V
  const CeedInt u_size = BASIS_Q_1D * (BASIS_DIM > 1 ? BASIS_Q_1D : 1) * (BASIS_DIM > 2 ? BASIS_Q_1D : 1);
  const CeedInt v_size = BASIS_P_1D * (BASIS_DIM > 1 ? BASIS_P_1D : 1) * (BASIS_DIM > 2 ? BASIS_P_1D : 1);

  // Thread coordinates and this z-layer's shared memory window
  SharedData_Hip data;
  data.t_id_x = threadIdx.x;
  data.t_id_y = threadIdx.y;
  data.t_id_z = threadIdx.z;
  data.t_id   = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.y * blockDim.x;
  data.slice  = slice + data.t_id_z * slice_size;

  // Per-thread input and output storage
  CeedScalar r_U[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)];
  CeedScalar r_V[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)];

  // Element loop: start at this z-layer's element and advance by the number of z-layers in the grid
  for (CeedInt elem = blockIdx.x * blockDim.z + threadIdx.z; elem < num_elem; elem += gridDim.x * blockDim.z) {
    switch (BASIS_DIM) {
      case 1:
        ReadElementStrided1d<BASIS_NUM_COMP, BASIS_Q_1D>(data, elem, 1, u_size * num_elem, u_size, d_U, r_U);
        InterpTranspose1d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, s_B, r_V);
        SumElementStrided1d<BASIS_NUM_COMP, BASIS_P_1D>(data, elem, 1, v_size * num_elem, v_size, r_V, d_V);
        break;
      case 2:
        ReadElementStrided2d<BASIS_NUM_COMP, BASIS_Q_1D>(data, elem, 1, u_size * num_elem, u_size, d_U, r_U);
        InterpTransposeTensor2d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, s_B, r_V);
        SumElementStrided2d<BASIS_NUM_COMP, BASIS_P_1D>(data, elem, 1, v_size * num_elem, v_size, r_V, d_V);
        break;
      case 3:
        ReadElementStrided3d<BASIS_NUM_COMP, BASIS_Q_1D>(data, elem, 1, u_size * num_elem, u_size, d_U, r_U);
        InterpTransposeTensor3d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, s_B, r_V);
        SumElementStrided3d<BASIS_NUM_COMP, BASIS_P_1D>(data, elem, 1, v_size * num_elem, v_size, r_V, d_V);
        break;
      default:
        break;
    }
  }
}
131 
132 //------------------------------------------------------------------------------
133 // Grad kernel by dim
134 //------------------------------------------------------------------------------
// Grad: apply the 1D interpolation/gradient matrices to d_U and store the result in d_V.
//   num_elem    - number of elements to process
//   d_interp_1d - BASIS_P_1D * BASIS_Q_1D matrix entries, staged in shared memory as s_B
//   d_grad_1d   - BASIS_Q_1D * (BASIS_HAS_COLLOCATED_GRAD ? BASIS_Q_1D : BASIS_P_1D) matrix entries, staged as s_G
//   d_U         - input values, fetched with ReadElementStrided{1,2,3}d
//   d_V         - output values, stored with WriteElementStrided{1,2,3}d (BASIS_NUM_COMP * BASIS_DIM components in 2D/3D)
// Every z-layer of the thread block works on one element per loop pass.
// NOTE(review): z-layers can leave the element loop at different times; if the tensor templates
//   synchronize the whole block internally, confirm that this is safe.
extern "C" __launch_bounds__(BASIS_GRAD_BLOCK_SIZE) __global__
    void Grad(const CeedInt num_elem, const CeedScalar *d_interp_1d, const CeedScalar *d_grad_1d, const CeedScalar *__restrict__ d_U,
              CeedScalar *__restrict__ d_V) {
  extern __shared__ CeedScalar slice[];

  // Stage interp_1d and grad_1d in shared memory; one barrier covers both loads
  __shared__ CeedScalar s_B[BASIS_P_1D * BASIS_Q_1D];
  __shared__ CeedScalar s_G[BASIS_Q_1D * (BASIS_HAS_COLLOCATED_GRAD ? BASIS_Q_1D : BASIS_P_1D)];
  loadMatrix<BASIS_P_1D * BASIS_Q_1D>(d_interp_1d, s_B);
  loadMatrix<BASIS_Q_1D *(BASIS_HAS_COLLOCATED_GRAD ? BASIS_Q_1D : BASIS_P_1D)>(d_grad_1d, s_G);
  __syncthreads();

  // Spacing between the windows of the dynamic shared array handed to consecutive z-layers
  const CeedInt slice_size = T_1D * (BASIS_DIM > 1 ? T_1D : 1);
  // P_1D^dim and Q_1D^dim: forwarded below as the stride arguments for d_U and d_V
  const CeedInt u_size = BASIS_P_1D * (BASIS_DIM > 1 ? BASIS_P_1D : 1) * (BASIS_DIM > 2 ? BASIS_P_1D : 1);
  const CeedInt v_size = BASIS_Q_1D * (BASIS_DIM > 1 ? BASIS_Q_1D : 1) * (BASIS_DIM > 2 ? BASIS_Q_1D : 1);

  // Thread coordinates and this z-layer's shared memory window
  SharedData_Hip data;
  data.t_id_x = threadIdx.x;
  data.t_id_y = threadIdx.y;
  data.t_id_z = threadIdx.z;
  data.t_id   = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.y * blockDim.x;
  data.slice  = slice + data.t_id_z * slice_size;

  // Per-thread input and output storage; the output holds BASIS_DIM times as many components
  CeedScalar r_U[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)];
  CeedScalar r_V[BASIS_NUM_COMP * BASIS_DIM * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)];

  // Element loop: start at this z-layer's element and advance by the number of z-layers in the grid
  for (CeedInt elem = blockIdx.x * blockDim.z + threadIdx.z; elem < num_elem; elem += gridDim.x * blockDim.z) {
    switch (BASIS_DIM) {
      case 1:
        ReadElementStrided1d<BASIS_NUM_COMP, BASIS_P_1D>(data, elem, 1, u_size * num_elem, u_size, d_U, r_U);
        Grad1d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, s_B, s_G, r_V);
        WriteElementStrided1d<BASIS_NUM_COMP, BASIS_Q_1D>(data, elem, 1, v_size * num_elem, v_size, r_V, d_V);
        break;
      case 2:
        ReadElementStrided2d<BASIS_NUM_COMP, BASIS_P_1D>(data, elem, 1, u_size * num_elem, u_size, d_U, r_U);
        GradTensor2d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, s_B, s_G, r_V);
        WriteElementStrided2d<BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D>(data, elem, 1, v_size * num_elem, v_size, r_V, d_V);
        break;
      case 3:
        ReadElementStrided3d<BASIS_NUM_COMP, BASIS_P_1D>(data, elem, 1, u_size * num_elem, u_size, d_U, r_U);
        // Only the 3D path has a collocated variant
        if (BASIS_HAS_COLLOCATED_GRAD) {
          GradTensorCollocated3d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, s_B, s_G, r_V);
        } else {
          GradTensor3d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, s_B, s_G, r_V);
        }
        WriteElementStrided3d<BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D>(data, elem, 1, v_size * num_elem, v_size, r_V, d_V);
        break;
      default:
        break;
    }
  }
}
177 
// GradTranspose: apply the transposed 1D interpolation/gradient operation to d_U and store the result in d_V.
//   num_elem    - number of elements to process
//   d_interp_1d - BASIS_P_1D * BASIS_Q_1D matrix entries, staged in shared memory as s_B
//   d_grad_1d   - BASIS_Q_1D * (BASIS_HAS_COLLOCATED_GRAD ? BASIS_Q_1D : BASIS_P_1D) matrix entries, staged as s_G
//   d_U         - input values, fetched with ReadElementStrided{1,2,3}d (BASIS_NUM_COMP * BASIS_DIM components in 2D/3D)
//   d_V         - output values, stored with WriteElementStrided{1,2,3}d
// Every z-layer of the thread block works on one element per loop pass.
// NOTE(review): z-layers can leave the element loop at different times; if the tensor templates
//   synchronize the whole block internally, confirm that this is safe.
extern "C" __launch_bounds__(BASIS_GRAD_BLOCK_SIZE) __global__
    void GradTranspose(const CeedInt num_elem, const CeedScalar *d_interp_1d, const CeedScalar *d_grad_1d, const CeedScalar *__restrict__ d_U,
                       CeedScalar *__restrict__ d_V) {
  extern __shared__ CeedScalar slice[];

  // Stage interp_1d and grad_1d in shared memory; one barrier covers both loads
  __shared__ CeedScalar s_B[BASIS_P_1D * BASIS_Q_1D];
  __shared__ CeedScalar s_G[BASIS_Q_1D * (BASIS_HAS_COLLOCATED_GRAD ? BASIS_Q_1D : BASIS_P_1D)];
  loadMatrix<BASIS_P_1D * BASIS_Q_1D>(d_interp_1d, s_B);
  loadMatrix<BASIS_Q_1D *(BASIS_HAS_COLLOCATED_GRAD ? BASIS_Q_1D : BASIS_P_1D)>(d_grad_1d, s_G);
  __syncthreads();

  // Spacing between the windows of the dynamic shared array handed to consecutive z-layers
  const CeedInt slice_size = T_1D * (BASIS_DIM > 1 ? T_1D : 1);
  // Q_1D^dim and P_1D^dim: forwarded below as the stride arguments for d_U and d_V
  const CeedInt u_size = BASIS_Q_1D * (BASIS_DIM > 1 ? BASIS_Q_1D : 1) * (BASIS_DIM > 2 ? BASIS_Q_1D : 1);
  const CeedInt v_size = BASIS_P_1D * (BASIS_DIM > 1 ? BASIS_P_1D : 1) * (BASIS_DIM > 2 ? BASIS_P_1D : 1);

  // Thread coordinates and this z-layer's shared memory window
  SharedData_Hip data;
  data.t_id_x = threadIdx.x;
  data.t_id_y = threadIdx.y;
  data.t_id_z = threadIdx.z;
  data.t_id   = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.y * blockDim.x;
  data.slice  = slice + data.t_id_z * slice_size;

  // Per-thread input and output storage; the input holds BASIS_DIM times as many components
  CeedScalar r_U[BASIS_NUM_COMP * BASIS_DIM * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)];
  CeedScalar r_V[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)];

  // Element loop: start at this z-layer's element and advance by the number of z-layers in the grid
  for (CeedInt elem = blockIdx.x * blockDim.z + threadIdx.z; elem < num_elem; elem += gridDim.x * blockDim.z) {
    switch (BASIS_DIM) {
      case 1:
        ReadElementStrided1d<BASIS_NUM_COMP, BASIS_Q_1D>(data, elem, 1, u_size * num_elem, u_size, d_U, r_U);
        GradTranspose1d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, s_B, s_G, r_V);
        WriteElementStrided1d<BASIS_NUM_COMP, BASIS_P_1D>(data, elem, 1, v_size * num_elem, v_size, r_V, d_V);
        break;
      case 2:
        ReadElementStrided2d<BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D>(data, elem, 1, u_size * num_elem, u_size, d_U, r_U);
        GradTransposeTensor2d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, s_B, s_G, r_V);
        WriteElementStrided2d<BASIS_NUM_COMP, BASIS_P_1D>(data, elem, 1, v_size * num_elem, v_size, r_V, d_V);
        break;
      case 3:
        ReadElementStrided3d<BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D>(data, elem, 1, u_size * num_elem, u_size, d_U, r_U);
        // Only the 3D path has a collocated variant
        if (BASIS_HAS_COLLOCATED_GRAD) {
          GradTransposeTensorCollocated3d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, s_B, s_G, r_V);
        } else {
          GradTransposeTensor3d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, s_B, s_G, r_V);
        }
        WriteElementStrided3d<BASIS_NUM_COMP, BASIS_P_1D>(data, elem, 1, v_size * num_elem, v_size, r_V, d_V);
        break;
      default:
        break;
    }
  }
}
220 
// GradTransposeAdd: same computation as GradTranspose, but the result is handed to
// SumElementStrided{1,2,3}d instead of WriteElementStrided{1,2,3}d
// (presumably accumulating into d_V rather than overwriting it - confirm in the read/write templates).
//   num_elem    - number of elements to process
//   d_interp_1d - BASIS_P_1D * BASIS_Q_1D matrix entries, staged in shared memory as s_B
//   d_grad_1d   - BASIS_Q_1D * (BASIS_HAS_COLLOCATED_GRAD ? BASIS_Q_1D : BASIS_P_1D) matrix entries, staged as s_G
//   d_U         - input values, fetched with ReadElementStrided{1,2,3}d (BASIS_NUM_COMP * BASIS_DIM components in 2D/3D)
//   d_V         - output values, updated with SumElementStrided{1,2,3}d
// Every z-layer of the thread block works on one element per loop pass.
// NOTE(review): z-layers can leave the element loop at different times; if the tensor templates
//   synchronize the whole block internally, confirm that this is safe.
extern "C" __launch_bounds__(BASIS_GRAD_BLOCK_SIZE) __global__
    void GradTransposeAdd(const CeedInt num_elem, const CeedScalar *d_interp_1d, const CeedScalar *d_grad_1d, const CeedScalar *__restrict__ d_U,
                          CeedScalar *__restrict__ d_V) {
  extern __shared__ CeedScalar slice[];

  // Stage interp_1d and grad_1d in shared memory; one barrier covers both loads
  __shared__ CeedScalar s_B[BASIS_P_1D * BASIS_Q_1D];
  __shared__ CeedScalar s_G[BASIS_Q_1D * (BASIS_HAS_COLLOCATED_GRAD ? BASIS_Q_1D : BASIS_P_1D)];
  loadMatrix<BASIS_P_1D * BASIS_Q_1D>(d_interp_1d, s_B);
  loadMatrix<BASIS_Q_1D *(BASIS_HAS_COLLOCATED_GRAD ? BASIS_Q_1D : BASIS_P_1D)>(d_grad_1d, s_G);
  __syncthreads();

  // Spacing between the windows of the dynamic shared array handed to consecutive z-layers
  const CeedInt slice_size = T_1D * (BASIS_DIM > 1 ? T_1D : 1);
  // Q_1D^dim and P_1D^dim: forwarded below as the stride arguments for d_U and d_V
  const CeedInt u_size = BASIS_Q_1D * (BASIS_DIM > 1 ? BASIS_Q_1D : 1) * (BASIS_DIM > 2 ? BASIS_Q_1D : 1);
  const CeedInt v_size = BASIS_P_1D * (BASIS_DIM > 1 ? BASIS_P_1D : 1) * (BASIS_DIM > 2 ? BASIS_P_1D : 1);

  // Thread coordinates and this z-layer's shared memory window
  SharedData_Hip data;
  data.t_id_x = threadIdx.x;
  data.t_id_y = threadIdx.y;
  data.t_id_z = threadIdx.z;
  data.t_id   = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.y * blockDim.x;
  data.slice  = slice + data.t_id_z * slice_size;

  // Per-thread input and output storage; the input holds BASIS_DIM times as many components
  CeedScalar r_U[BASIS_NUM_COMP * BASIS_DIM * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)];
  CeedScalar r_V[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)];

  // Element loop: start at this z-layer's element and advance by the number of z-layers in the grid
  for (CeedInt elem = blockIdx.x * blockDim.z + threadIdx.z; elem < num_elem; elem += gridDim.x * blockDim.z) {
    switch (BASIS_DIM) {
      case 1:
        ReadElementStrided1d<BASIS_NUM_COMP, BASIS_Q_1D>(data, elem, 1, u_size * num_elem, u_size, d_U, r_U);
        GradTranspose1d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, s_B, s_G, r_V);
        SumElementStrided1d<BASIS_NUM_COMP, BASIS_P_1D>(data, elem, 1, v_size * num_elem, v_size, r_V, d_V);
        break;
      case 2:
        ReadElementStrided2d<BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D>(data, elem, 1, u_size * num_elem, u_size, d_U, r_U);
        GradTransposeTensor2d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, s_B, s_G, r_V);
        SumElementStrided2d<BASIS_NUM_COMP, BASIS_P_1D>(data, elem, 1, v_size * num_elem, v_size, r_V, d_V);
        break;
      case 3:
        ReadElementStrided3d<BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D>(data, elem, 1, u_size * num_elem, u_size, d_U, r_U);
        // Only the 3D path has a collocated variant
        if (BASIS_HAS_COLLOCATED_GRAD) {
          GradTransposeTensorCollocated3d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, s_B, s_G, r_V);
        } else {
          GradTransposeTensor3d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, s_B, s_G, r_V);
        }
        SumElementStrided3d<BASIS_NUM_COMP, BASIS_P_1D>(data, elem, 1, v_size * num_elem, v_size, r_V, d_V);
        break;
      default:
        break;
    }
  }
}
263 
264 //------------------------------------------------------------------------------
265 // Weight kernels by dim
266 //------------------------------------------------------------------------------
// Weight: build per-element weights from q_weight_1d and store them in d_W.
//   num_elem    - number of elements to process
//   q_weight_1d - 1D weights handed to Weight1d / WeightTensor2d / WeightTensor3d
//   d_W         - output, one component per element, stored with WriteElementStrided{1,2,3}d
// Every z-layer of the thread block works on one element per loop pass.
extern "C" __launch_bounds__(BASIS_WEIGHT_BLOCK_SIZE) __global__
    void Weight(const CeedInt num_elem, const CeedScalar *__restrict__ q_weight_1d, CeedScalar *__restrict__ d_W) {
  extern __shared__ CeedScalar slice[];

  // Spacing between the windows of the dynamic shared array handed to consecutive z-layers
  const CeedInt slice_size = T_1D * (BASIS_DIM > 1 ? T_1D : 1);
  // Q_1D^dim: forwarded below as the stride arguments for d_W
  const CeedInt w_size = BASIS_Q_1D * (BASIS_DIM > 1 ? BASIS_Q_1D : 1) * (BASIS_DIM > 2 ? BASIS_Q_1D : 1);

  // Thread coordinates and this z-layer's shared memory window
  SharedData_Hip data;
  data.t_id_x = threadIdx.x;
  data.t_id_y = threadIdx.y;
  data.t_id_z = threadIdx.z;
  data.t_id   = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.y * blockDim.x;
  data.slice  = slice + data.t_id_z * slice_size;

  // Per-thread weight storage
  CeedScalar r_W[BASIS_DIM > 2 ? BASIS_Q_1D : 1];

  // Element loop: start at this z-layer's element and advance by the number of z-layers in the grid
  for (CeedInt elem = blockIdx.x * blockDim.z + threadIdx.z; elem < num_elem; elem += gridDim.x * blockDim.z) {
    switch (BASIS_DIM) {
      case 1:
        Weight1d<BASIS_Q_1D>(data, q_weight_1d, r_W);
        WriteElementStrided1d<1, BASIS_Q_1D>(data, elem, 1, w_size * num_elem, w_size, r_W, d_W);
        break;
      case 2:
        WeightTensor2d<BASIS_Q_1D>(data, q_weight_1d, r_W);
        WriteElementStrided2d<1, BASIS_Q_1D>(data, elem, 1, w_size * num_elem, w_size, r_W, d_W);
        break;
      case 3:
        WeightTensor3d<BASIS_Q_1D>(data, q_weight_1d, r_W);
        WriteElementStrided3d<1, BASIS_Q_1D>(data, elem, 1, w_size * num_elem, w_size, r_W, d_W);
        break;
      default:
        break;
    }
  }
}
294