xref: /libCEED/include/ceed/jit-source/hip/hip-shared-basis-read-write-templates.h (revision b6a2eb7998676e206f6df72229aa5643127bbcef)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 /// @file
9 /// Internal header for HIP shared memory basis read/write templates
10 #include <ceed/types.h>
11 
12 //------------------------------------------------------------------------------
13 // Helper function: load matrices for basis actions
14 //------------------------------------------------------------------------------
// Copy SIZE scalars from d_B into B, spreading the work over every thread of the block.
// Threads are ranked x-fastest (then y, then z); the thread of rank r copies entries
// r, r + nthreads, r + 2 * nthreads, ... where nthreads = blockDim.x * blockDim.y * blockDim.z.
// No barrier is issued here. If B is read by threads other than the one that wrote each
// entry (B is presumably a shared-memory buffer -- confirm at the call site), the caller
// must synchronize before using it.
template <int SIZE>
inline __device__ void loadMatrix(const CeedScalar *d_B, CeedScalar *B) {
  const unsigned int num_threads = blockDim.x * blockDim.y * blockDim.z;
  const CeedInt      rank        = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);

  for (CeedInt entry = rank; entry < SIZE; entry += num_threads) {
    B[entry] = d_B[entry];
  }
}
21 
22 //------------------------------------------------------------------------------
23 // 1D
24 //------------------------------------------------------------------------------
25 
26 //------------------------------------------------------------------------------
27 // E-vector -> single element
28 //------------------------------------------------------------------------------
// Gather all NUM_COMP components of one node of element `elem` from a strided E-vector.
// The calling thread handles node data.t_id_x; threads with t_id_x >= P_1D return without
// touching r_u. Component c of node n is read from
//   d_u[n * strides_node + elem * strides_elem + c * strides_comp]
// and stored in r_u[c].
template <int NUM_COMP, int P_1D>
inline __device__ void ReadElementStrided1d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                            const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
  // Threads beyond the last node have nothing to load
  if (data.t_id_x >= P_1D) return;

  const CeedInt my_node = data.t_id_x;
  const CeedInt offset  = my_node * strides_node + elem * strides_elem;

  for (CeedInt c = 0; c < NUM_COMP; c++) r_u[c] = d_u[offset + c * strides_comp];
}
41 
42 //------------------------------------------------------------------------------
43 // Single element -> E-vector
44 //------------------------------------------------------------------------------
// Scatter all NUM_COMP components of one node of element `elem` into a strided E-vector,
// overwriting the existing values. The calling thread handles node data.t_id_x; threads
// with t_id_x >= P_1D write nothing. r_v[c] is stored to
//   d_v[n * strides_node + elem * strides_elem + c * strides_comp].
template <int NUM_COMP, int P_1D>
inline __device__ void WriteElementStrided1d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                             const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
  // Threads beyond the last node have nothing to store
  if (data.t_id_x >= P_1D) return;

  const CeedInt my_node = data.t_id_x;
  const CeedInt offset  = my_node * strides_node + elem * strides_elem;

  for (CeedInt c = 0; c < NUM_COMP; c++) d_v[offset + c * strides_comp] = r_v[c];
}
57 
// Accumulate all NUM_COMP components of one node of element `elem` into a strided E-vector
// with a plain (non-atomic) read-modify-write. The calling thread handles node data.t_id_x;
// threads with t_id_x >= P_1D do nothing. r_v[c] is added to
//   d_v[n * strides_node + elem * strides_elem + c * strides_comp].
template <int NUM_COMP, int P_1D>
inline __device__ void SumElementStrided1d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                           const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
  // Threads beyond the last node have nothing to accumulate
  if (data.t_id_x >= P_1D) return;

  const CeedInt my_node = data.t_id_x;
  const CeedInt offset  = my_node * strides_node + elem * strides_elem;

  for (CeedInt c = 0; c < NUM_COMP; c++) d_v[offset + c * strides_comp] += r_v[c];
}
70 
71 //------------------------------------------------------------------------------
72 // 2D
73 //------------------------------------------------------------------------------
74 
75 //------------------------------------------------------------------------------
76 // E-vector -> single element
77 //------------------------------------------------------------------------------
// Gather all NUM_COMP components of one node of a 2D tensor element from a strided E-vector.
// The calling thread handles node n = t_id_x + t_id_y * P_1D; threads outside the
// P_1D x P_1D tile return without touching r_u. Component c is read from
//   d_u[n * strides_node + elem * strides_elem + c * strides_comp]
// into r_u[c].
template <int NUM_COMP, int P_1D>
inline __device__ void ReadElementStrided2d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                            const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
  // Only threads inside the P_1D x P_1D tile own a node
  if (data.t_id_x >= P_1D || data.t_id_y >= P_1D) return;

  const CeedInt my_node = data.t_id_x + data.t_id_y * P_1D;
  const CeedInt offset  = my_node * strides_node + elem * strides_elem;

  for (CeedInt c = 0; c < NUM_COMP; c++) r_u[c] = d_u[offset + c * strides_comp];
}
90 
91 //------------------------------------------------------------------------------
92 // Single element -> E-vector
93 //------------------------------------------------------------------------------
// Scatter all NUM_COMP components of one node of a 2D tensor element into a strided
// E-vector, overwriting the existing values. The calling thread handles node
// n = t_id_x + t_id_y * P_1D; threads outside the P_1D x P_1D tile write nothing.
// r_v[c] is stored to d_v[n * strides_node + elem * strides_elem + c * strides_comp].
template <int NUM_COMP, int P_1D>
inline __device__ void WriteElementStrided2d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                             const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
  // Only threads inside the P_1D x P_1D tile own a node
  if (data.t_id_x >= P_1D || data.t_id_y >= P_1D) return;

  const CeedInt my_node = data.t_id_x + data.t_id_y * P_1D;
  const CeedInt offset  = my_node * strides_node + elem * strides_elem;

  for (CeedInt c = 0; c < NUM_COMP; c++) d_v[offset + c * strides_comp] = r_v[c];
}
106 
// Accumulate all NUM_COMP components of one node of a 2D tensor element into a strided
// E-vector with a plain (non-atomic) read-modify-write. The calling thread handles node
// n = t_id_x + t_id_y * P_1D; threads outside the P_1D x P_1D tile do nothing.
// r_v[c] is added to d_v[n * strides_node + elem * strides_elem + c * strides_comp].
template <int NUM_COMP, int P_1D>
inline __device__ void SumElementStrided2d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                           const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
  // Only threads inside the P_1D x P_1D tile own a node
  if (data.t_id_x >= P_1D || data.t_id_y >= P_1D) return;

  const CeedInt my_node = data.t_id_x + data.t_id_y * P_1D;
  const CeedInt offset  = my_node * strides_node + elem * strides_elem;

  for (CeedInt c = 0; c < NUM_COMP; c++) d_v[offset + c * strides_comp] += r_v[c];
}
119 
120 //------------------------------------------------------------------------------
121 // 3D
122 //------------------------------------------------------------------------------
123 
124 //------------------------------------------------------------------------------
125 // E-vector -> single element
126 //------------------------------------------------------------------------------
// Gather one (x, y) column of a 3D tensor element from a strided E-vector.
// The calling thread owns column (t_id_x, t_id_y) and walks all P_1D z-layers; threads
// outside the P_1D x P_1D tile return without touching r_u. Node
// n = t_id_x + t_id_y * P_1D + z * P_1D * P_1D, component c is read from
//   d_u[n * strides_node + elem * strides_elem + c * strides_comp]
// into r_u[z + c * P_1D], so r_u must hold at least NUM_COMP * P_1D values.
template <int NUM_COMP, int P_1D>
inline __device__ void ReadElementStrided3d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                            const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
  // Only threads inside the P_1D x P_1D tile own a column
  if (data.t_id_x >= P_1D || data.t_id_y >= P_1D) return;

  const CeedInt column = data.t_id_x + data.t_id_y * P_1D;

  for (CeedInt z = 0; z < P_1D; z++) {
    const CeedInt my_node = column + z * P_1D * P_1D;
    const CeedInt offset  = my_node * strides_node + elem * strides_elem;

    for (CeedInt c = 0; c < NUM_COMP; c++) r_u[z + c * P_1D] = d_u[offset + c * strides_comp];
  }
}
141 
142 //------------------------------------------------------------------------------
143 // Single element -> E-vector
144 //------------------------------------------------------------------------------
// Scatter one (x, y) column of a 3D tensor element into a strided E-vector, overwriting
// the existing values. The calling thread owns column (t_id_x, t_id_y) and walks all
// P_1D z-layers; threads outside the P_1D x P_1D tile write nothing. r_v[z + c * P_1D]
// is stored to d_v[n * strides_node + elem * strides_elem + c * strides_comp] with
// n = t_id_x + t_id_y * P_1D + z * P_1D * P_1D.
template <int NUM_COMP, int P_1D>
inline __device__ void WriteElementStrided3d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                             const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
  // Only threads inside the P_1D x P_1D tile own a column
  if (data.t_id_x >= P_1D || data.t_id_y >= P_1D) return;

  const CeedInt column = data.t_id_x + data.t_id_y * P_1D;

  for (CeedInt z = 0; z < P_1D; z++) {
    const CeedInt my_node = column + z * P_1D * P_1D;
    const CeedInt offset  = my_node * strides_node + elem * strides_elem;

    for (CeedInt c = 0; c < NUM_COMP; c++) d_v[offset + c * strides_comp] = r_v[z + c * P_1D];
  }
}
159 
// Accumulate one (x, y) column of a 3D tensor element into a strided E-vector with a
// plain (non-atomic) read-modify-write. The calling thread owns column (t_id_x, t_id_y)
// and walks all P_1D z-layers; threads outside the P_1D x P_1D tile do nothing.
// r_v[z + c * P_1D] is added to d_v[n * strides_node + elem * strides_elem + c * strides_comp]
// with n = t_id_x + t_id_y * P_1D + z * P_1D * P_1D.
template <int NUM_COMP, int P_1D>
inline __device__ void SumElementStrided3d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                           const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
  // Only threads inside the P_1D x P_1D tile own a column
  if (data.t_id_x >= P_1D || data.t_id_y >= P_1D) return;

  const CeedInt column = data.t_id_x + data.t_id_y * P_1D;

  for (CeedInt z = 0; z < P_1D; z++) {
    const CeedInt my_node = column + z * P_1D * P_1D;
    const CeedInt offset  = my_node * strides_node + elem * strides_elem;

    for (CeedInt c = 0; c < NUM_COMP; c++) d_v[offset + c * strides_comp] += r_v[z + c * P_1D];
  }
}
174 
175 //------------------------------------------------------------------------------
176 // AtPoints
177 //------------------------------------------------------------------------------
178 
179 //------------------------------------------------------------------------------
180 // E-vector -> single point
181 //------------------------------------------------------------------------------
// Gather all NUM_COMP components of point p of element `elem` from a strided E-vector.
// When p < points_in_elem, r_u[c] receives
//   d_u[(p % NUM_PTS) * strides_point + elem * strides_elem + c * strides_comp];
// otherwise r_u is zero-filled and d_u is not read. `data` is not used.
template <int NUM_COMP, int NUM_PTS>
inline __device__ void ReadPoint(SharedData_Hip &data, const CeedInt elem, const CeedInt p, const CeedInt points_in_elem, const CeedInt strides_point,
                                 const CeedInt strides_comp, const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
  const bool    has_point = p < points_in_elem;
  const CeedInt offset    = (p % NUM_PTS) * strides_point + elem * strides_elem;

  for (CeedInt c = 0; c < NUM_COMP; c++) {
    // Only the selected operand is evaluated, so padding points never load from d_u
    r_u[c] = has_point ? d_u[offset + c * strides_comp] : CeedScalar(0.0);
  }
}
197 
198 //------------------------------------------------------------------------------
199 // Single point -> E-vector
200 //------------------------------------------------------------------------------
// Scatter all NUM_COMP components of point p of element `elem` into a strided E-vector,
// overwriting the existing values. Nothing is written when p >= points_in_elem.
// r_v[c] is stored to d_v[(p % NUM_PTS) * strides_point + elem * strides_elem + c * strides_comp].
// `data` is not used.
template <int NUM_COMP, int NUM_PTS>
inline __device__ void WritePoint(SharedData_Hip &data, const CeedInt elem, const CeedInt p, const CeedInt points_in_elem,
                                  const CeedInt strides_point, const CeedInt strides_comp, const CeedInt strides_elem, const CeedScalar *r_v,
                                  CeedScalar *d_v) {
  // Padding points (beyond points_in_elem) have no storage to write to
  if (p >= points_in_elem) return;

  const CeedInt offset = (p % NUM_PTS) * strides_point + elem * strides_elem;

  for (CeedInt c = 0; c < NUM_COMP; c++) d_v[offset + c * strides_comp] = r_v[c];
}
213