xref: /libCEED/include/ceed/jit-source/hip/hip-shared-basis-read-write-templates.h (revision 9c25dd66b9687765a7022cc762ccaf201b721845)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 /// @file
9 /// Internal header for HIP shared memory basis read/write templates
10 
11 #include <ceed/types.h>
12 
13 //------------------------------------------------------------------------------
14 // Helper function: load matrices for basis actions
15 //------------------------------------------------------------------------------
16 template <int SIZE>
17 inline __device__ void loadMatrix(const CeedScalar *d_B, CeedScalar *B) {
18   CeedInt tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.y * blockDim.x;
19 
20   for (CeedInt i = tid; i < SIZE; i += blockDim.x * blockDim.y * blockDim.z) B[i] = d_B[i];
21 }
22 
23 //------------------------------------------------------------------------------
24 // 1D
25 //------------------------------------------------------------------------------
26 
27 //------------------------------------------------------------------------------
28 // E-vector -> single element
29 //------------------------------------------------------------------------------
30 template <int NUM_COMP, int P_1D>
31 inline __device__ void ReadElementStrided1d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
32                                             const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
33   if (data.t_id_x < P_1D) {
34     const CeedInt node = data.t_id_x;
35     const CeedInt ind  = node * strides_node + elem * strides_elem;
36 
37     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
38       r_u[comp] = d_u[ind + comp * strides_comp];
39     }
40   }
41 }
42 
43 //------------------------------------------------------------------------------
44 // Single element -> E-vector
45 //------------------------------------------------------------------------------
46 template <int NUM_COMP, int P_1D>
47 inline __device__ void WriteElementStrided1d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
48                                              const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
49   if (data.t_id_x < P_1D) {
50     const CeedInt node = data.t_id_x;
51     const CeedInt ind  = node * strides_node + elem * strides_elem;
52 
53     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
54       d_v[ind + comp * strides_comp] = r_v[comp];
55     }
56   }
57 }
58 
59 template <int NUM_COMP, int P_1D>
60 inline __device__ void SumElementStrided1d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
61                                            const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
62   if (data.t_id_x < P_1D) {
63     const CeedInt node = data.t_id_x;
64     const CeedInt ind  = node * strides_node + elem * strides_elem;
65 
66     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
67       d_v[ind + comp * strides_comp] += r_v[comp];
68     }
69   }
70 }
71 
72 //------------------------------------------------------------------------------
73 // 2D
74 //------------------------------------------------------------------------------
75 
76 //------------------------------------------------------------------------------
77 // E-vector -> single element
78 //------------------------------------------------------------------------------
79 template <int NUM_COMP, int P_1D>
80 inline __device__ void ReadElementStrided2d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
81                                             const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
82   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
83     const CeedInt node = data.t_id_x + data.t_id_y * P_1D;
84     const CeedInt ind  = node * strides_node + elem * strides_elem;
85 
86     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
87       r_u[comp] = d_u[ind + comp * strides_comp];
88     }
89   }
90 }
91 
92 //------------------------------------------------------------------------------
93 // Single element -> E-vector
94 //------------------------------------------------------------------------------
95 template <int NUM_COMP, int P_1D>
96 inline __device__ void WriteElementStrided2d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
97                                              const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
98   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
99     const CeedInt node = data.t_id_x + data.t_id_y * P_1D;
100     const CeedInt ind  = node * strides_node + elem * strides_elem;
101 
102     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
103       d_v[ind + comp * strides_comp] = r_v[comp];
104     }
105   }
106 }
107 
108 template <int NUM_COMP, int P_1D>
109 inline __device__ void SumElementStrided2d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
110                                            const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
111   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
112     const CeedInt node = data.t_id_x + data.t_id_y * P_1D;
113     const CeedInt ind  = node * strides_node + elem * strides_elem;
114 
115     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
116       d_v[ind + comp * strides_comp] += r_v[comp];
117     }
118   }
119 }
120 
121 //------------------------------------------------------------------------------
122 // 3D
123 //------------------------------------------------------------------------------
124 
125 //------------------------------------------------------------------------------
126 // E-vector -> single element
127 //------------------------------------------------------------------------------
128 template <int NUM_COMP, int P_1D>
129 inline __device__ void ReadElementStrided3d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
130                                             const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
131   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
132     for (CeedInt z = 0; z < P_1D; z++) {
133       const CeedInt node = data.t_id_x + data.t_id_y * P_1D + z * P_1D * P_1D;
134       const CeedInt ind  = node * strides_node + elem * strides_elem;
135 
136       for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
137         r_u[z + comp * P_1D] = d_u[ind + comp * strides_comp];
138       }
139     }
140   }
141 }
142 
143 //------------------------------------------------------------------------------
144 // Single element -> E-vector
145 //------------------------------------------------------------------------------
146 template <int NUM_COMP, int P_1D>
147 inline __device__ void WriteElementStrided3d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
148                                              const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
149   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
150     for (CeedInt z = 0; z < P_1D; z++) {
151       const CeedInt node = data.t_id_x + data.t_id_y * P_1D + z * P_1D * P_1D;
152       const CeedInt ind  = node * strides_node + elem * strides_elem;
153 
154       for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
155         d_v[ind + comp * strides_comp] = r_v[z + comp * P_1D];
156       }
157     }
158   }
159 }
160 
161 template <int NUM_COMP, int P_1D>
162 inline __device__ void SumElementStrided3d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
163                                            const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
164   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
165     for (CeedInt z = 0; z < P_1D; z++) {
166       const CeedInt node = data.t_id_x + data.t_id_y * P_1D + z * P_1D * P_1D;
167       const CeedInt ind  = node * strides_node + elem * strides_elem;
168 
169       for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
170         d_v[ind + comp * strides_comp] += r_v[z + comp * P_1D];
171       }
172     }
173   }
174 }
175