xref: /libCEED/include/ceed/jit-source/hip/hip-shared-basis-read-write-templates.h (revision d4cc18453651bd0f94c1a2e078b2646a92dafdcc) !
// Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED:  http://github.com/ceed

/// @file
/// Internal header for HIP shared memory basis read/write templates
10 #include <ceed/types.h>

//------------------------------------------------------------------------------
// Helper function: load basis matrices for basis actions
//------------------------------------------------------------------------------
15 template <int P, int Q>
LoadMatrix(SharedData_Hip & data,const CeedScalar * __restrict__ d_B,CeedScalar * B)16 inline __device__ void LoadMatrix(SharedData_Hip &data, const CeedScalar *__restrict__ d_B, CeedScalar *B) {
17   for (CeedInt i = data.t_id; i < P * Q; i += blockDim.x * blockDim.y * blockDim.z) B[i] = d_B[i];
18 }

//------------------------------------------------------------------------------
// 1D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// E-vector -> single element
//------------------------------------------------------------------------------
27 template <int NUM_COMP, int P_1D>
ReadElementStrided1d(SharedData_Hip & data,const CeedInt elem,const CeedInt strides_node,const CeedInt strides_comp,const CeedInt strides_elem,const CeedScalar * __restrict__ d_u,CeedScalar * r_u)28 inline __device__ void ReadElementStrided1d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
29                                             const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
30   if (data.t_id_x < P_1D) {
31     const CeedInt node = data.t_id_x;
32     const CeedInt ind  = node * strides_node + elem * strides_elem;
33 
34     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
35       r_u[comp] = d_u[ind + comp * strides_comp];
36     }
37   }
38 }

//------------------------------------------------------------------------------
// Single element -> E-vector (overwrite and accumulate variants)
//------------------------------------------------------------------------------
43 template <int NUM_COMP, int P_1D>
WriteElementStrided1d(SharedData_Hip & data,const CeedInt elem,const CeedInt strides_node,const CeedInt strides_comp,const CeedInt strides_elem,const CeedScalar * r_v,CeedScalar * d_v)44 inline __device__ void WriteElementStrided1d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
45                                              const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
46   if (data.t_id_x < P_1D) {
47     const CeedInt node = data.t_id_x;
48     const CeedInt ind  = node * strides_node + elem * strides_elem;
49 
50     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
51       d_v[ind + comp * strides_comp] = r_v[comp];
52     }
53   }
54 }
55 
56 template <int NUM_COMP, int P_1D>
SumElementStrided1d(SharedData_Hip & data,const CeedInt elem,const CeedInt strides_node,const CeedInt strides_comp,const CeedInt strides_elem,const CeedScalar * r_v,CeedScalar * d_v)57 inline __device__ void SumElementStrided1d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
58                                            const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
59   if (data.t_id_x < P_1D) {
60     const CeedInt node = data.t_id_x;
61     const CeedInt ind  = node * strides_node + elem * strides_elem;
62 
63     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
64       d_v[ind + comp * strides_comp] += r_v[comp];
65     }
66   }
67 }

//------------------------------------------------------------------------------
// 2D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// E-vector -> single element
//------------------------------------------------------------------------------
76 template <int NUM_COMP, int P_1D>
ReadElementStrided2d(SharedData_Hip & data,const CeedInt elem,const CeedInt strides_node,const CeedInt strides_comp,const CeedInt strides_elem,const CeedScalar * __restrict__ d_u,CeedScalar * r_u)77 inline __device__ void ReadElementStrided2d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
78                                             const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
79   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
80     const CeedInt node = data.t_id_x + data.t_id_y * P_1D;
81     const CeedInt ind  = node * strides_node + elem * strides_elem;
82 
83     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
84       r_u[comp] = d_u[ind + comp * strides_comp];
85     }
86   }
87 }

//------------------------------------------------------------------------------
// Single element -> E-vector (overwrite and accumulate variants)
//------------------------------------------------------------------------------
92 template <int NUM_COMP, int P_1D>
WriteElementStrided2d(SharedData_Hip & data,const CeedInt elem,const CeedInt strides_node,const CeedInt strides_comp,const CeedInt strides_elem,const CeedScalar * r_v,CeedScalar * d_v)93 inline __device__ void WriteElementStrided2d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
94                                              const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
95   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
96     const CeedInt node = data.t_id_x + data.t_id_y * P_1D;
97     const CeedInt ind  = node * strides_node + elem * strides_elem;
98 
99     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
100       d_v[ind + comp * strides_comp] = r_v[comp];
101     }
102   }
103 }
104 
105 template <int NUM_COMP, int P_1D>
SumElementStrided2d(SharedData_Hip & data,const CeedInt elem,const CeedInt strides_node,const CeedInt strides_comp,const CeedInt strides_elem,const CeedScalar * r_v,CeedScalar * d_v)106 inline __device__ void SumElementStrided2d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
107                                            const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
108   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
109     const CeedInt node = data.t_id_x + data.t_id_y * P_1D;
110     const CeedInt ind  = node * strides_node + elem * strides_elem;
111 
112     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
113       d_v[ind + comp * strides_comp] += r_v[comp];
114     }
115   }
116 }

//------------------------------------------------------------------------------
// 3D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// E-vector -> single element
//------------------------------------------------------------------------------
125 template <int NUM_COMP, int P_1D>
ReadElementStrided3d(SharedData_Hip & data,const CeedInt elem,const CeedInt strides_node,const CeedInt strides_comp,const CeedInt strides_elem,const CeedScalar * __restrict__ d_u,CeedScalar * r_u)126 inline __device__ void ReadElementStrided3d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
127                                             const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
128   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
129     for (CeedInt z = 0; z < P_1D; z++) {
130       const CeedInt node = data.t_id_x + data.t_id_y * P_1D + z * P_1D * P_1D;
131       const CeedInt ind  = node * strides_node + elem * strides_elem;
132 
133       for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
134         r_u[z + comp * P_1D] = d_u[ind + comp * strides_comp];
135       }
136     }
137   }
138 }

//------------------------------------------------------------------------------
// Single element -> E-vector (overwrite and accumulate variants)
//------------------------------------------------------------------------------
143 template <int NUM_COMP, int P_1D>
WriteElementStrided3d(SharedData_Hip & data,const CeedInt elem,const CeedInt strides_node,const CeedInt strides_comp,const CeedInt strides_elem,const CeedScalar * r_v,CeedScalar * d_v)144 inline __device__ void WriteElementStrided3d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
145                                              const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
146   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
147     for (CeedInt z = 0; z < P_1D; z++) {
148       const CeedInt node = data.t_id_x + data.t_id_y * P_1D + z * P_1D * P_1D;
149       const CeedInt ind  = node * strides_node + elem * strides_elem;
150 
151       for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
152         d_v[ind + comp * strides_comp] = r_v[z + comp * P_1D];
153       }
154     }
155   }
156 }
157 
158 template <int NUM_COMP, int P_1D>
SumElementStrided3d(SharedData_Hip & data,const CeedInt elem,const CeedInt strides_node,const CeedInt strides_comp,const CeedInt strides_elem,const CeedScalar * r_v,CeedScalar * d_v)159 inline __device__ void SumElementStrided3d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
160                                            const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
161   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
162     for (CeedInt z = 0; z < P_1D; z++) {
163       const CeedInt node = data.t_id_x + data.t_id_y * P_1D + z * P_1D * P_1D;
164       const CeedInt ind  = node * strides_node + elem * strides_elem;
165 
166       for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
167         d_v[ind + comp * strides_comp] += r_v[z + comp * P_1D];
168       }
169     }
170   }
171 }

//------------------------------------------------------------------------------
// AtPoints
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// E-vector -> single point
//------------------------------------------------------------------------------
180 template <int NUM_COMP, int NUM_PTS>
ReadPoint(SharedData_Hip & data,const CeedInt elem,const CeedInt p,const CeedInt points_in_elem,const CeedInt strides_point,const CeedInt strides_comp,const CeedInt strides_elem,const CeedScalar * __restrict__ d_u,CeedScalar * r_u)181 inline __device__ void ReadPoint(SharedData_Hip &data, const CeedInt elem, const CeedInt p, const CeedInt points_in_elem, const CeedInt strides_point,
182                                  const CeedInt strides_comp, const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
183   const CeedInt ind = (p % NUM_PTS) * strides_point + elem * strides_elem;
184 
185   if (p < points_in_elem) {
186     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
187       r_u[comp] = d_u[ind + comp * strides_comp];
188     }
189   } else {
190     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
191       r_u[comp] = 0.0;
192     }
193   }
194 }

//------------------------------------------------------------------------------
// Single point -> E-vector
//------------------------------------------------------------------------------
199 template <int NUM_COMP, int NUM_PTS>
WritePoint(SharedData_Hip & data,const CeedInt elem,const CeedInt p,const CeedInt points_in_elem,const CeedInt strides_point,const CeedInt strides_comp,const CeedInt strides_elem,const CeedScalar * r_v,CeedScalar * d_v)200 inline __device__ void WritePoint(SharedData_Hip &data, const CeedInt elem, const CeedInt p, const CeedInt points_in_elem,
201                                   const CeedInt strides_point, const CeedInt strides_comp, const CeedInt strides_elem, const CeedScalar *r_v,
202                                   CeedScalar *d_v) {
203   if (p < points_in_elem) {
204     const CeedInt ind = (p % NUM_PTS) * strides_point + elem * strides_elem;
205 
206     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
207       d_v[ind + comp * strides_comp] = r_v[comp];
208     }
209   }
210 }
211