xref: /libCEED/include/ceed/jit-source/hip/hip-shared-basis-read-write-templates.h (revision d4cc18453651bd0f94c1a2e078b2646a92dafdcc)
// Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED:  http://github.com/ceed

/// @file
/// Internal header for HIP shared memory basis read/write templates
10c0b5abf0SJeremy L Thompson #include <ceed/types.h>
119e201c85SYohann 
129e201c85SYohann //------------------------------------------------------------------------------
139e201c85SYohann // Helper function: load matrices for basis actions
149e201c85SYohann //------------------------------------------------------------------------------
15aa4002adSJeremy L Thompson template <int P, int Q>
LoadMatrix(SharedData_Hip & data,const CeedScalar * __restrict__ d_B,CeedScalar * B)16aa4002adSJeremy L Thompson inline __device__ void LoadMatrix(SharedData_Hip &data, const CeedScalar *__restrict__ d_B, CeedScalar *B) {
17aa4002adSJeremy L Thompson   for (CeedInt i = data.t_id; i < P * Q; i += blockDim.x * blockDim.y * blockDim.z) B[i] = d_B[i];
189e201c85SYohann }
199e201c85SYohann 
//------------------------------------------------------------------------------
// 1D
//------------------------------------------------------------------------------
239e201c85SYohann 
249e201c85SYohann //------------------------------------------------------------------------------
259e201c85SYohann // E-vector -> single element
269e201c85SYohann //------------------------------------------------------------------------------
279e201c85SYohann template <int NUM_COMP, int P_1D>
ReadElementStrided1d(SharedData_Hip & data,const CeedInt elem,const CeedInt strides_node,const CeedInt strides_comp,const CeedInt strides_elem,const CeedScalar * __restrict__ d_u,CeedScalar * r_u)282b730f8bSJeremy L Thompson inline __device__ void ReadElementStrided1d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
292b730f8bSJeremy L Thompson                                             const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
309e201c85SYohann   if (data.t_id_x < P_1D) {
319e201c85SYohann     const CeedInt node = data.t_id_x;
329e201c85SYohann     const CeedInt ind  = node * strides_node + elem * strides_elem;
33672b0f2aSSebastian Grimberg 
349e201c85SYohann     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
359e201c85SYohann       r_u[comp] = d_u[ind + comp * strides_comp];
369e201c85SYohann     }
379e201c85SYohann   }
389e201c85SYohann }
399e201c85SYohann 
409e201c85SYohann //------------------------------------------------------------------------------
419e201c85SYohann // Single element -> E-vector
429e201c85SYohann //------------------------------------------------------------------------------
439e201c85SYohann template <int NUM_COMP, int P_1D>
WriteElementStrided1d(SharedData_Hip & data,const CeedInt elem,const CeedInt strides_node,const CeedInt strides_comp,const CeedInt strides_elem,const CeedScalar * r_v,CeedScalar * d_v)442b730f8bSJeremy L Thompson inline __device__ void WriteElementStrided1d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
452b730f8bSJeremy L Thompson                                              const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
469e201c85SYohann   if (data.t_id_x < P_1D) {
479e201c85SYohann     const CeedInt node = data.t_id_x;
489e201c85SYohann     const CeedInt ind  = node * strides_node + elem * strides_elem;
49672b0f2aSSebastian Grimberg 
509e201c85SYohann     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
519e201c85SYohann       d_v[ind + comp * strides_comp] = r_v[comp];
529e201c85SYohann     }
539e201c85SYohann   }
549e201c85SYohann }
559e201c85SYohann 
56db2becc9SJeremy L Thompson template <int NUM_COMP, int P_1D>
SumElementStrided1d(SharedData_Hip & data,const CeedInt elem,const CeedInt strides_node,const CeedInt strides_comp,const CeedInt strides_elem,const CeedScalar * r_v,CeedScalar * d_v)57db2becc9SJeremy L Thompson inline __device__ void SumElementStrided1d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
58db2becc9SJeremy L Thompson                                            const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
59db2becc9SJeremy L Thompson   if (data.t_id_x < P_1D) {
60db2becc9SJeremy L Thompson     const CeedInt node = data.t_id_x;
61db2becc9SJeremy L Thompson     const CeedInt ind  = node * strides_node + elem * strides_elem;
62db2becc9SJeremy L Thompson 
63db2becc9SJeremy L Thompson     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
64db2becc9SJeremy L Thompson       d_v[ind + comp * strides_comp] += r_v[comp];
65db2becc9SJeremy L Thompson     }
66db2becc9SJeremy L Thompson   }
67db2becc9SJeremy L Thompson }
68db2becc9SJeremy L Thompson 
//------------------------------------------------------------------------------
// 2D
//------------------------------------------------------------------------------
729e201c85SYohann 
739e201c85SYohann //------------------------------------------------------------------------------
749e201c85SYohann // E-vector -> single element
759e201c85SYohann //------------------------------------------------------------------------------
769e201c85SYohann template <int NUM_COMP, int P_1D>
ReadElementStrided2d(SharedData_Hip & data,const CeedInt elem,const CeedInt strides_node,const CeedInt strides_comp,const CeedInt strides_elem,const CeedScalar * __restrict__ d_u,CeedScalar * r_u)772b730f8bSJeremy L Thompson inline __device__ void ReadElementStrided2d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
782b730f8bSJeremy L Thompson                                             const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
799e201c85SYohann   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
809e201c85SYohann     const CeedInt node = data.t_id_x + data.t_id_y * P_1D;
819e201c85SYohann     const CeedInt ind  = node * strides_node + elem * strides_elem;
82672b0f2aSSebastian Grimberg 
839e201c85SYohann     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
849e201c85SYohann       r_u[comp] = d_u[ind + comp * strides_comp];
859e201c85SYohann     }
869e201c85SYohann   }
879e201c85SYohann }
889e201c85SYohann 
899e201c85SYohann //------------------------------------------------------------------------------
909e201c85SYohann // Single element -> E-vector
919e201c85SYohann //------------------------------------------------------------------------------
929e201c85SYohann template <int NUM_COMP, int P_1D>
WriteElementStrided2d(SharedData_Hip & data,const CeedInt elem,const CeedInt strides_node,const CeedInt strides_comp,const CeedInt strides_elem,const CeedScalar * r_v,CeedScalar * d_v)932b730f8bSJeremy L Thompson inline __device__ void WriteElementStrided2d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
942b730f8bSJeremy L Thompson                                              const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
959e201c85SYohann   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
969e201c85SYohann     const CeedInt node = data.t_id_x + data.t_id_y * P_1D;
979e201c85SYohann     const CeedInt ind  = node * strides_node + elem * strides_elem;
98672b0f2aSSebastian Grimberg 
999e201c85SYohann     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
1009e201c85SYohann       d_v[ind + comp * strides_comp] = r_v[comp];
1019e201c85SYohann     }
1029e201c85SYohann   }
1039e201c85SYohann }
1049e201c85SYohann 
105db2becc9SJeremy L Thompson template <int NUM_COMP, int P_1D>
SumElementStrided2d(SharedData_Hip & data,const CeedInt elem,const CeedInt strides_node,const CeedInt strides_comp,const CeedInt strides_elem,const CeedScalar * r_v,CeedScalar * d_v)106db2becc9SJeremy L Thompson inline __device__ void SumElementStrided2d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
107db2becc9SJeremy L Thompson                                            const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
108db2becc9SJeremy L Thompson   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
109db2becc9SJeremy L Thompson     const CeedInt node = data.t_id_x + data.t_id_y * P_1D;
110db2becc9SJeremy L Thompson     const CeedInt ind  = node * strides_node + elem * strides_elem;
111db2becc9SJeremy L Thompson 
112db2becc9SJeremy L Thompson     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
113db2becc9SJeremy L Thompson       d_v[ind + comp * strides_comp] += r_v[comp];
114db2becc9SJeremy L Thompson     }
115db2becc9SJeremy L Thompson   }
116db2becc9SJeremy L Thompson }
117db2becc9SJeremy L Thompson 
//------------------------------------------------------------------------------
// 3D
//------------------------------------------------------------------------------
1219e201c85SYohann 
1229e201c85SYohann //------------------------------------------------------------------------------
1239e201c85SYohann // E-vector -> single element
1249e201c85SYohann //------------------------------------------------------------------------------
1259e201c85SYohann template <int NUM_COMP, int P_1D>
ReadElementStrided3d(SharedData_Hip & data,const CeedInt elem,const CeedInt strides_node,const CeedInt strides_comp,const CeedInt strides_elem,const CeedScalar * __restrict__ d_u,CeedScalar * r_u)1262b730f8bSJeremy L Thompson inline __device__ void ReadElementStrided3d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
1272b730f8bSJeremy L Thompson                                             const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
1289e201c85SYohann   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
1299e201c85SYohann     for (CeedInt z = 0; z < P_1D; z++) {
1309e201c85SYohann       const CeedInt node = data.t_id_x + data.t_id_y * P_1D + z * P_1D * P_1D;
1319e201c85SYohann       const CeedInt ind  = node * strides_node + elem * strides_elem;
132672b0f2aSSebastian Grimberg 
1339e201c85SYohann       for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
1349e201c85SYohann         r_u[z + comp * P_1D] = d_u[ind + comp * strides_comp];
1359e201c85SYohann       }
1369e201c85SYohann     }
1379e201c85SYohann   }
1389e201c85SYohann }
1399e201c85SYohann 
1409e201c85SYohann //------------------------------------------------------------------------------
1419e201c85SYohann // Single element -> E-vector
1429e201c85SYohann //------------------------------------------------------------------------------
1439e201c85SYohann template <int NUM_COMP, int P_1D>
WriteElementStrided3d(SharedData_Hip & data,const CeedInt elem,const CeedInt strides_node,const CeedInt strides_comp,const CeedInt strides_elem,const CeedScalar * r_v,CeedScalar * d_v)1442b730f8bSJeremy L Thompson inline __device__ void WriteElementStrided3d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
1452b730f8bSJeremy L Thompson                                              const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
1469e201c85SYohann   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
1479e201c85SYohann     for (CeedInt z = 0; z < P_1D; z++) {
1489e201c85SYohann       const CeedInt node = data.t_id_x + data.t_id_y * P_1D + z * P_1D * P_1D;
1499e201c85SYohann       const CeedInt ind  = node * strides_node + elem * strides_elem;
150672b0f2aSSebastian Grimberg 
1519e201c85SYohann       for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
1529e201c85SYohann         d_v[ind + comp * strides_comp] = r_v[z + comp * P_1D];
1539e201c85SYohann       }
1549e201c85SYohann     }
1559e201c85SYohann   }
1569e201c85SYohann }
157db2becc9SJeremy L Thompson 
158db2becc9SJeremy L Thompson template <int NUM_COMP, int P_1D>
SumElementStrided3d(SharedData_Hip & data,const CeedInt elem,const CeedInt strides_node,const CeedInt strides_comp,const CeedInt strides_elem,const CeedScalar * r_v,CeedScalar * d_v)159db2becc9SJeremy L Thompson inline __device__ void SumElementStrided3d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
160db2becc9SJeremy L Thompson                                            const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
161db2becc9SJeremy L Thompson   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
162db2becc9SJeremy L Thompson     for (CeedInt z = 0; z < P_1D; z++) {
163db2becc9SJeremy L Thompson       const CeedInt node = data.t_id_x + data.t_id_y * P_1D + z * P_1D * P_1D;
164db2becc9SJeremy L Thompson       const CeedInt ind  = node * strides_node + elem * strides_elem;
165db2becc9SJeremy L Thompson 
166db2becc9SJeremy L Thompson       for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
167db2becc9SJeremy L Thompson         d_v[ind + comp * strides_comp] += r_v[z + comp * P_1D];
168db2becc9SJeremy L Thompson       }
169db2becc9SJeremy L Thompson     }
170db2becc9SJeremy L Thompson   }
171db2becc9SJeremy L Thompson }
172b6a2eb79SJeremy L Thompson 
//------------------------------------------------------------------------------
// AtPoints
//------------------------------------------------------------------------------
176b6a2eb79SJeremy L Thompson 
177b6a2eb79SJeremy L Thompson //------------------------------------------------------------------------------
178b6a2eb79SJeremy L Thompson // E-vector -> single point
179b6a2eb79SJeremy L Thompson //------------------------------------------------------------------------------
180b6a2eb79SJeremy L Thompson template <int NUM_COMP, int NUM_PTS>
ReadPoint(SharedData_Hip & data,const CeedInt elem,const CeedInt p,const CeedInt points_in_elem,const CeedInt strides_point,const CeedInt strides_comp,const CeedInt strides_elem,const CeedScalar * __restrict__ d_u,CeedScalar * r_u)181b6a2eb79SJeremy L Thompson inline __device__ void ReadPoint(SharedData_Hip &data, const CeedInt elem, const CeedInt p, const CeedInt points_in_elem, const CeedInt strides_point,
182b6a2eb79SJeremy L Thompson                                  const CeedInt strides_comp, const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
183b6a2eb79SJeremy L Thompson   const CeedInt ind = (p % NUM_PTS) * strides_point + elem * strides_elem;
184b6a2eb79SJeremy L Thompson 
185b6a2eb79SJeremy L Thompson   if (p < points_in_elem) {
186b6a2eb79SJeremy L Thompson     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
187b6a2eb79SJeremy L Thompson       r_u[comp] = d_u[ind + comp * strides_comp];
188b6a2eb79SJeremy L Thompson     }
189b6a2eb79SJeremy L Thompson   } else {
190b6a2eb79SJeremy L Thompson     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
191b6a2eb79SJeremy L Thompson       r_u[comp] = 0.0;
192b6a2eb79SJeremy L Thompson     }
193b6a2eb79SJeremy L Thompson   }
194b6a2eb79SJeremy L Thompson }
195b6a2eb79SJeremy L Thompson 
196b6a2eb79SJeremy L Thompson //------------------------------------------------------------------------------
197b6a2eb79SJeremy L Thompson // Single point -> E-vector
198b6a2eb79SJeremy L Thompson //------------------------------------------------------------------------------
199b6a2eb79SJeremy L Thompson template <int NUM_COMP, int NUM_PTS>
WritePoint(SharedData_Hip & data,const CeedInt elem,const CeedInt p,const CeedInt points_in_elem,const CeedInt strides_point,const CeedInt strides_comp,const CeedInt strides_elem,const CeedScalar * r_v,CeedScalar * d_v)200b6a2eb79SJeremy L Thompson inline __device__ void WritePoint(SharedData_Hip &data, const CeedInt elem, const CeedInt p, const CeedInt points_in_elem,
201b6a2eb79SJeremy L Thompson                                   const CeedInt strides_point, const CeedInt strides_comp, const CeedInt strides_elem, const CeedScalar *r_v,
202b6a2eb79SJeremy L Thompson                                   CeedScalar *d_v) {
203b6a2eb79SJeremy L Thompson   if (p < points_in_elem) {
204b6a2eb79SJeremy L Thompson     const CeedInt ind = (p % NUM_PTS) * strides_point + elem * strides_elem;
205b6a2eb79SJeremy L Thompson 
206b6a2eb79SJeremy L Thompson     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
207b6a2eb79SJeremy L Thompson       d_v[ind + comp * strides_comp] = r_v[comp];
208b6a2eb79SJeremy L Thompson     }
209b6a2eb79SJeremy L Thompson   }
210b6a2eb79SJeremy L Thompson }
211