xref: /libCEED/include/ceed/jit-source/cuda/cuda-shared-basis-tensor.h (revision 4753b775a3a8f79e2dd83c2aab10890a6a04e913)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 /// @file
9 /// Internal header for CUDA shared memory tensor product basis
10 #include <ceed/types.h>
11 
12 #include "cuda-shared-basis-read-write-templates.h"
13 #include "cuda-shared-basis-tensor-templates.h"
14 
15 //------------------------------------------------------------------------------
16 // Interp kernel by dim
17 //------------------------------------------------------------------------------
// Interpolation kernel. For every element handled by this thread's z-layer, values are read from
// d_U into a thread-local array, passed with c_B through the Interp template matching BASIS_DIM,
// and the result is written to d_V.
//
// Element assignment: a z-layer starts at blockIdx.x * blockDim.z + threadIdx.z and then steps by
// gridDim.x * blockDim.z until num_elem is reached.
// Shared memory: each z-layer gets its own window into the dynamic `slice` array, offset by
// T_1D scalars per layer in 1D and T_1D * T_1D scalars per layer in 2D/3D.
// NOTE(review): the three stride arguments (1, size * num_elem, size) look like node, component
// and element strides -- confirm against cuda-shared-basis-read-write-templates.h.
extern "C" __global__ void Interp(const CeedInt num_elem, const CeedScalar *c_B, const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V) {
  extern __shared__ CeedScalar slice[];

  // Values per component per element, i.e. BASIS_P_1D^BASIS_DIM and BASIS_Q_1D^BASIS_DIM (for BASIS_DIM of 1, 2 or 3)
  const CeedInt p_elem_size = BASIS_P_1D * (BASIS_DIM > 1 ? BASIS_P_1D : 1) * (BASIS_DIM > 2 ? BASIS_P_1D : 1);
  const CeedInt q_elem_size = BASIS_Q_1D * (BASIS_DIM > 1 ? BASIS_Q_1D : 1) * (BASIS_DIM > 2 ? BASIS_Q_1D : 1);

  // Thread coordinates plus this z-layer's window into the shared scratch array
  SharedData_Cuda ctx;
  ctx.t_id_x = threadIdx.x;
  ctx.t_id_y = threadIdx.y;
  ctx.t_id_z = threadIdx.z;
  ctx.t_id   = threadIdx.x + blockDim.x * (threadIdx.y + threadIdx.z * blockDim.y);
  ctx.slice  = &slice[ctx.t_id_z * T_1D * (BASIS_DIM > 1 ? T_1D : 1)];

  // Thread-local input and output; only the 3D path holds more than one value per component
  CeedScalar r_in[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)];
  CeedScalar r_out[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)];

  // Kept unsigned so the loop increment uses the same mixed signed/unsigned arithmetic as before
  const unsigned int elem_step = gridDim.x * blockDim.z;

  for (CeedInt e = blockIdx.x * blockDim.z + threadIdx.z; e < num_elem; e += elem_step) {
    switch (BASIS_DIM) {
      case 1:
        ReadElementStrided1d<BASIS_NUM_COMP, BASIS_P_1D>(ctx, e, 1, p_elem_size * num_elem, p_elem_size, d_U, r_in);
        Interp1d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(ctx, r_in, c_B, r_out);
        WriteElementStrided1d<BASIS_NUM_COMP, BASIS_Q_1D>(ctx, e, 1, q_elem_size * num_elem, q_elem_size, r_out, d_V);
        break;
      case 2:
        ReadElementStrided2d<BASIS_NUM_COMP, BASIS_P_1D>(ctx, e, 1, p_elem_size * num_elem, p_elem_size, d_U, r_in);
        InterpTensor2d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(ctx, r_in, c_B, r_out);
        WriteElementStrided2d<BASIS_NUM_COMP, BASIS_Q_1D>(ctx, e, 1, q_elem_size * num_elem, q_elem_size, r_out, d_V);
        break;
      case 3:
        ReadElementStrided3d<BASIS_NUM_COMP, BASIS_P_1D>(ctx, e, 1, p_elem_size * num_elem, p_elem_size, d_U, r_in);
        InterpTensor3d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(ctx, r_in, c_B, r_out);
        WriteElementStrided3d<BASIS_NUM_COMP, BASIS_Q_1D>(ctx, e, 1, q_elem_size * num_elem, q_elem_size, r_out, d_V);
        break;
    }
  }
}
49 
// Transposed interpolation kernel. For every element handled by this thread's z-layer, values
// sized by BASIS_Q_1D are read from d_U, passed with c_B through the InterpTranspose template
// matching BASIS_DIM, and the BASIS_P_1D-sized result is written to d_V.
//
// Element assignment: a z-layer starts at blockIdx.x * blockDim.z + threadIdx.z and then steps by
// gridDim.x * blockDim.z until num_elem is reached.
// Shared memory: each z-layer gets its own window into the dynamic `slice` array, offset by
// T_1D scalars per layer in 1D and T_1D * T_1D scalars per layer in 2D/3D.
// NOTE(review): the three stride arguments (1, size * num_elem, size) look like node, component
// and element strides -- confirm against cuda-shared-basis-read-write-templates.h.
extern "C" __global__ void InterpTranspose(const CeedInt num_elem, const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                           CeedScalar *__restrict__ d_V) {
  extern __shared__ CeedScalar slice[];

  // Values per component per element, i.e. BASIS_P_1D^BASIS_DIM and BASIS_Q_1D^BASIS_DIM (for BASIS_DIM of 1, 2 or 3)
  const CeedInt p_elem_size = BASIS_P_1D * (BASIS_DIM > 1 ? BASIS_P_1D : 1) * (BASIS_DIM > 2 ? BASIS_P_1D : 1);
  const CeedInt q_elem_size = BASIS_Q_1D * (BASIS_DIM > 1 ? BASIS_Q_1D : 1) * (BASIS_DIM > 2 ? BASIS_Q_1D : 1);

  // Thread coordinates plus this z-layer's window into the shared scratch array
  SharedData_Cuda ctx;
  ctx.t_id_x = threadIdx.x;
  ctx.t_id_y = threadIdx.y;
  ctx.t_id_z = threadIdx.z;
  ctx.t_id   = threadIdx.x + blockDim.x * (threadIdx.y + threadIdx.z * blockDim.y);
  ctx.slice  = &slice[ctx.t_id_z * T_1D * (BASIS_DIM > 1 ? T_1D : 1)];

  // Thread-local input (Q-sized) and output (P-sized); only the 3D path holds more than one value per component
  CeedScalar r_in[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)];
  CeedScalar r_out[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)];

  // Kept unsigned so the loop increment uses the same mixed signed/unsigned arithmetic as before
  const unsigned int elem_step = gridDim.x * blockDim.z;

  for (CeedInt e = blockIdx.x * blockDim.z + threadIdx.z; e < num_elem; e += elem_step) {
    switch (BASIS_DIM) {
      case 1:
        ReadElementStrided1d<BASIS_NUM_COMP, BASIS_Q_1D>(ctx, e, 1, q_elem_size * num_elem, q_elem_size, d_U, r_in);
        InterpTranspose1d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(ctx, r_in, c_B, r_out);
        WriteElementStrided1d<BASIS_NUM_COMP, BASIS_P_1D>(ctx, e, 1, p_elem_size * num_elem, p_elem_size, r_out, d_V);
        break;
      case 2:
        ReadElementStrided2d<BASIS_NUM_COMP, BASIS_Q_1D>(ctx, e, 1, q_elem_size * num_elem, q_elem_size, d_U, r_in);
        InterpTransposeTensor2d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(ctx, r_in, c_B, r_out);
        WriteElementStrided2d<BASIS_NUM_COMP, BASIS_P_1D>(ctx, e, 1, p_elem_size * num_elem, p_elem_size, r_out, d_V);
        break;
      case 3:
        ReadElementStrided3d<BASIS_NUM_COMP, BASIS_Q_1D>(ctx, e, 1, q_elem_size * num_elem, q_elem_size, d_U, r_in);
        InterpTransposeTensor3d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(ctx, r_in, c_B, r_out);
        WriteElementStrided3d<BASIS_NUM_COMP, BASIS_P_1D>(ctx, e, 1, p_elem_size * num_elem, p_elem_size, r_out, d_V);
        break;
    }
  }
}
82 
// Transposed interpolation kernel, summing variant. It performs the same read and
// InterpTranspose steps as the plain transpose kernel but hands the result to the
// SumElementStrided templates instead of the WriteElementStrided ones.
// NOTE(review): presumably this accumulates into the existing contents of d_V -- confirm against
// the SumElementStrided templates.
//
// Element assignment: a z-layer starts at blockIdx.x * blockDim.z + threadIdx.z and then steps by
// gridDim.x * blockDim.z until num_elem is reached.
// Shared memory: each z-layer gets its own window into the dynamic `slice` array, offset by
// T_1D scalars per layer in 1D and T_1D * T_1D scalars per layer in 2D/3D.
extern "C" __global__ void InterpTransposeAdd(const CeedInt num_elem, const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                              CeedScalar *__restrict__ d_V) {
  extern __shared__ CeedScalar slice[];

  // Values per component per element, i.e. BASIS_P_1D^BASIS_DIM and BASIS_Q_1D^BASIS_DIM (for BASIS_DIM of 1, 2 or 3)
  const CeedInt p_elem_size = BASIS_P_1D * (BASIS_DIM > 1 ? BASIS_P_1D : 1) * (BASIS_DIM > 2 ? BASIS_P_1D : 1);
  const CeedInt q_elem_size = BASIS_Q_1D * (BASIS_DIM > 1 ? BASIS_Q_1D : 1) * (BASIS_DIM > 2 ? BASIS_Q_1D : 1);

  // Thread coordinates plus this z-layer's window into the shared scratch array
  SharedData_Cuda ctx;
  ctx.t_id_x = threadIdx.x;
  ctx.t_id_y = threadIdx.y;
  ctx.t_id_z = threadIdx.z;
  ctx.t_id   = threadIdx.x + blockDim.x * (threadIdx.y + threadIdx.z * blockDim.y);
  ctx.slice  = &slice[ctx.t_id_z * T_1D * (BASIS_DIM > 1 ? T_1D : 1)];

  // Thread-local input (Q-sized) and output (P-sized); only the 3D path holds more than one value per component
  CeedScalar r_in[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)];
  CeedScalar r_out[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)];

  // Kept unsigned so the loop increment uses the same mixed signed/unsigned arithmetic as before
  const unsigned int elem_step = gridDim.x * blockDim.z;

  for (CeedInt e = blockIdx.x * blockDim.z + threadIdx.z; e < num_elem; e += elem_step) {
    switch (BASIS_DIM) {
      case 1:
        ReadElementStrided1d<BASIS_NUM_COMP, BASIS_Q_1D>(ctx, e, 1, q_elem_size * num_elem, q_elem_size, d_U, r_in);
        InterpTranspose1d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(ctx, r_in, c_B, r_out);
        SumElementStrided1d<BASIS_NUM_COMP, BASIS_P_1D>(ctx, e, 1, p_elem_size * num_elem, p_elem_size, r_out, d_V);
        break;
      case 2:
        ReadElementStrided2d<BASIS_NUM_COMP, BASIS_Q_1D>(ctx, e, 1, q_elem_size * num_elem, q_elem_size, d_U, r_in);
        InterpTransposeTensor2d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(ctx, r_in, c_B, r_out);
        SumElementStrided2d<BASIS_NUM_COMP, BASIS_P_1D>(ctx, e, 1, p_elem_size * num_elem, p_elem_size, r_out, d_V);
        break;
      case 3:
        ReadElementStrided3d<BASIS_NUM_COMP, BASIS_Q_1D>(ctx, e, 1, q_elem_size * num_elem, q_elem_size, d_U, r_in);
        InterpTransposeTensor3d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(ctx, r_in, c_B, r_out);
        SumElementStrided3d<BASIS_NUM_COMP, BASIS_P_1D>(ctx, e, 1, p_elem_size * num_elem, p_elem_size, r_out, d_V);
        break;
    }
  }
}
115 
116 //------------------------------------------------------------------------------
117 // Grad kernel by dim
118 //------------------------------------------------------------------------------
// Gradient kernel. For every element handled by this thread's z-layer, values are read from d_U,
// passed with c_B and c_G through the Grad template matching BASIS_DIM, and the result is written
// to d_V. In 2D and 3D the write uses BASIS_NUM_COMP * BASIS_DIM components; in 3D the collocated
// template is chosen when BASIS_HAS_COLLOCATED_GRAD is set.
//
// Element assignment: a z-layer starts at blockIdx.x * blockDim.z + threadIdx.z and then steps by
// gridDim.x * blockDim.z until num_elem is reached.
// Shared memory: each z-layer gets its own window into the dynamic `slice` array, offset by
// T_1D scalars per layer in 1D and T_1D * T_1D scalars per layer in 2D/3D.
// NOTE(review): the three stride arguments (1, size * num_elem, size) look like node, component
// and element strides -- confirm against cuda-shared-basis-read-write-templates.h.
extern "C" __global__ void Grad(const CeedInt num_elem, const CeedScalar *c_B, const CeedScalar *c_G, const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V) {
  extern __shared__ CeedScalar slice[];

  // Values per component per element, i.e. BASIS_P_1D^BASIS_DIM and BASIS_Q_1D^BASIS_DIM (for BASIS_DIM of 1, 2 or 3)
  const CeedInt p_elem_size = BASIS_P_1D * (BASIS_DIM > 1 ? BASIS_P_1D : 1) * (BASIS_DIM > 2 ? BASIS_P_1D : 1);
  const CeedInt q_elem_size = BASIS_Q_1D * (BASIS_DIM > 1 ? BASIS_Q_1D : 1) * (BASIS_DIM > 2 ? BASIS_Q_1D : 1);

  // Thread coordinates plus this z-layer's window into the shared scratch array
  SharedData_Cuda ctx;
  ctx.t_id_x = threadIdx.x;
  ctx.t_id_y = threadIdx.y;
  ctx.t_id_z = threadIdx.z;
  ctx.t_id   = threadIdx.x + blockDim.x * (threadIdx.y + threadIdx.z * blockDim.y);
  ctx.slice  = &slice[ctx.t_id_z * T_1D * (BASIS_DIM > 1 ? T_1D : 1)];

  // Thread-local input and output; the output carries BASIS_DIM entries per component
  CeedScalar r_in[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)];
  CeedScalar r_grad[BASIS_NUM_COMP * BASIS_DIM * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)];

  // Kept unsigned so the loop increment uses the same mixed signed/unsigned arithmetic as before
  const unsigned int elem_step = gridDim.x * blockDim.z;

  for (CeedInt e = blockIdx.x * blockDim.z + threadIdx.z; e < num_elem; e += elem_step) {
    switch (BASIS_DIM) {
      case 1:
        ReadElementStrided1d<BASIS_NUM_COMP, BASIS_P_1D>(ctx, e, 1, p_elem_size * num_elem, p_elem_size, d_U, r_in);
        Grad1d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(ctx, r_in, c_B, c_G, r_grad);
        WriteElementStrided1d<BASIS_NUM_COMP, BASIS_Q_1D>(ctx, e, 1, q_elem_size * num_elem, q_elem_size, r_grad, d_V);
        break;
      case 2:
        ReadElementStrided2d<BASIS_NUM_COMP, BASIS_P_1D>(ctx, e, 1, p_elem_size * num_elem, p_elem_size, d_U, r_in);
        GradTensor2d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(ctx, r_in, c_B, c_G, r_grad);
        WriteElementStrided2d<BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D>(ctx, e, 1, q_elem_size * num_elem, q_elem_size, r_grad, d_V);
        break;
      case 3:
        ReadElementStrided3d<BASIS_NUM_COMP, BASIS_P_1D>(ctx, e, 1, p_elem_size * num_elem, p_elem_size, d_U, r_in);
        if (BASIS_HAS_COLLOCATED_GRAD) {
          GradTensorCollocated3d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(ctx, r_in, c_B, c_G, r_grad);
        } else {
          GradTensor3d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(ctx, r_in, c_B, c_G, r_grad);
        }
        WriteElementStrided3d<BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D>(ctx, e, 1, q_elem_size * num_elem, q_elem_size, r_grad, d_V);
        break;
    }
  }
}
153 
// Transposed gradient kernel. For every element handled by this thread's z-layer, values sized by
// BASIS_Q_1D are read from d_U (with BASIS_NUM_COMP * BASIS_DIM components in 2D and 3D), passed
// with c_B and c_G through the GradTranspose template matching BASIS_DIM, and the
// BASIS_P_1D-sized result is written to d_V. In 3D the collocated template is chosen when
// BASIS_HAS_COLLOCATED_GRAD is set.
//
// Element assignment: a z-layer starts at blockIdx.x * blockDim.z + threadIdx.z and then steps by
// gridDim.x * blockDim.z until num_elem is reached.
// Shared memory: each z-layer gets its own window into the dynamic `slice` array, offset by
// T_1D scalars per layer in 1D and T_1D * T_1D scalars per layer in 2D/3D.
// NOTE(review): the three stride arguments (1, size * num_elem, size) look like node, component
// and element strides -- confirm against cuda-shared-basis-read-write-templates.h.
extern "C" __global__ void GradTranspose(const CeedInt num_elem, const CeedScalar *c_B, const CeedScalar *c_G, const CeedScalar *__restrict__ d_U,
                                         CeedScalar *__restrict__ d_V) {
  extern __shared__ CeedScalar slice[];

  // Values per component per element, i.e. BASIS_P_1D^BASIS_DIM and BASIS_Q_1D^BASIS_DIM (for BASIS_DIM of 1, 2 or 3)
  const CeedInt p_elem_size = BASIS_P_1D * (BASIS_DIM > 1 ? BASIS_P_1D : 1) * (BASIS_DIM > 2 ? BASIS_P_1D : 1);
  const CeedInt q_elem_size = BASIS_Q_1D * (BASIS_DIM > 1 ? BASIS_Q_1D : 1) * (BASIS_DIM > 2 ? BASIS_Q_1D : 1);

  // Thread coordinates plus this z-layer's window into the shared scratch array
  SharedData_Cuda ctx;
  ctx.t_id_x = threadIdx.x;
  ctx.t_id_y = threadIdx.y;
  ctx.t_id_z = threadIdx.z;
  ctx.t_id   = threadIdx.x + blockDim.x * (threadIdx.y + threadIdx.z * blockDim.y);
  ctx.slice  = &slice[ctx.t_id_z * T_1D * (BASIS_DIM > 1 ? T_1D : 1)];

  // Thread-local input and output; the input carries BASIS_DIM entries per component
  CeedScalar r_grad[BASIS_NUM_COMP * BASIS_DIM * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)];
  CeedScalar r_out[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)];

  // Kept unsigned so the loop increment uses the same mixed signed/unsigned arithmetic as before
  const unsigned int elem_step = gridDim.x * blockDim.z;

  for (CeedInt e = blockIdx.x * blockDim.z + threadIdx.z; e < num_elem; e += elem_step) {
    switch (BASIS_DIM) {
      case 1:
        ReadElementStrided1d<BASIS_NUM_COMP, BASIS_Q_1D>(ctx, e, 1, q_elem_size * num_elem, q_elem_size, d_U, r_grad);
        GradTranspose1d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(ctx, r_grad, c_B, c_G, r_out);
        WriteElementStrided1d<BASIS_NUM_COMP, BASIS_P_1D>(ctx, e, 1, p_elem_size * num_elem, p_elem_size, r_out, d_V);
        break;
      case 2:
        ReadElementStrided2d<BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D>(ctx, e, 1, q_elem_size * num_elem, q_elem_size, d_U, r_grad);
        GradTransposeTensor2d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(ctx, r_grad, c_B, c_G, r_out);
        WriteElementStrided2d<BASIS_NUM_COMP, BASIS_P_1D>(ctx, e, 1, p_elem_size * num_elem, p_elem_size, r_out, d_V);
        break;
      case 3:
        ReadElementStrided3d<BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D>(ctx, e, 1, q_elem_size * num_elem, q_elem_size, d_U, r_grad);
        if (BASIS_HAS_COLLOCATED_GRAD) {
          GradTransposeTensorCollocated3d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(ctx, r_grad, c_B, c_G, r_out);
        } else {
          GradTransposeTensor3d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(ctx, r_grad, c_B, c_G, r_out);
        }
        WriteElementStrided3d<BASIS_NUM_COMP, BASIS_P_1D>(ctx, e, 1, p_elem_size * num_elem, p_elem_size, r_out, d_V);
        break;
    }
  }
}
188 
// Transposed gradient kernel, summing variant. It performs the same read and GradTranspose steps
// as the plain transpose kernel but hands the result to the SumElementStrided templates instead
// of the WriteElementStrided ones.
// NOTE(review): presumably this accumulates into the existing contents of d_V -- confirm against
// the SumElementStrided templates.
//
// Element assignment: a z-layer starts at blockIdx.x * blockDim.z + threadIdx.z and then steps by
// gridDim.x * blockDim.z until num_elem is reached.
// Shared memory: each z-layer gets its own window into the dynamic `slice` array, offset by
// T_1D scalars per layer in 1D and T_1D * T_1D scalars per layer in 2D/3D.
extern "C" __global__ void GradTransposeAdd(const CeedInt num_elem, const CeedScalar *c_B, const CeedScalar *c_G, const CeedScalar *__restrict__ d_U,
                                            CeedScalar *__restrict__ d_V) {
  extern __shared__ CeedScalar slice[];

  // Values per component per element, i.e. BASIS_P_1D^BASIS_DIM and BASIS_Q_1D^BASIS_DIM (for BASIS_DIM of 1, 2 or 3)
  const CeedInt p_elem_size = BASIS_P_1D * (BASIS_DIM > 1 ? BASIS_P_1D : 1) * (BASIS_DIM > 2 ? BASIS_P_1D : 1);
  const CeedInt q_elem_size = BASIS_Q_1D * (BASIS_DIM > 1 ? BASIS_Q_1D : 1) * (BASIS_DIM > 2 ? BASIS_Q_1D : 1);

  // Thread coordinates plus this z-layer's window into the shared scratch array
  SharedData_Cuda ctx;
  ctx.t_id_x = threadIdx.x;
  ctx.t_id_y = threadIdx.y;
  ctx.t_id_z = threadIdx.z;
  ctx.t_id   = threadIdx.x + blockDim.x * (threadIdx.y + threadIdx.z * blockDim.y);
  ctx.slice  = &slice[ctx.t_id_z * T_1D * (BASIS_DIM > 1 ? T_1D : 1)];

  // Thread-local input and output; the input carries BASIS_DIM entries per component
  CeedScalar r_grad[BASIS_NUM_COMP * BASIS_DIM * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)];
  CeedScalar r_out[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)];

  // Kept unsigned so the loop increment uses the same mixed signed/unsigned arithmetic as before
  const unsigned int elem_step = gridDim.x * blockDim.z;

  for (CeedInt e = blockIdx.x * blockDim.z + threadIdx.z; e < num_elem; e += elem_step) {
    switch (BASIS_DIM) {
      case 1:
        ReadElementStrided1d<BASIS_NUM_COMP, BASIS_Q_1D>(ctx, e, 1, q_elem_size * num_elem, q_elem_size, d_U, r_grad);
        GradTranspose1d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(ctx, r_grad, c_B, c_G, r_out);
        SumElementStrided1d<BASIS_NUM_COMP, BASIS_P_1D>(ctx, e, 1, p_elem_size * num_elem, p_elem_size, r_out, d_V);
        break;
      case 2:
        ReadElementStrided2d<BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D>(ctx, e, 1, q_elem_size * num_elem, q_elem_size, d_U, r_grad);
        GradTransposeTensor2d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(ctx, r_grad, c_B, c_G, r_out);
        SumElementStrided2d<BASIS_NUM_COMP, BASIS_P_1D>(ctx, e, 1, p_elem_size * num_elem, p_elem_size, r_out, d_V);
        break;
      case 3:
        ReadElementStrided3d<BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D>(ctx, e, 1, q_elem_size * num_elem, q_elem_size, d_U, r_grad);
        if (BASIS_HAS_COLLOCATED_GRAD) {
          GradTransposeTensorCollocated3d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(ctx, r_grad, c_B, c_G, r_out);
        } else {
          GradTransposeTensor3d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(ctx, r_grad, c_B, c_G, r_out);
        }
        SumElementStrided3d<BASIS_NUM_COMP, BASIS_P_1D>(ctx, e, 1, p_elem_size * num_elem, p_elem_size, r_out, d_V);
        break;
    }
  }
}
223 
224 //------------------------------------------------------------------------------
225 // Weight kernels by dim
226 //------------------------------------------------------------------------------
// Weight kernel. For every element handled by this thread's z-layer, the Weight template matching
// BASIS_DIM fills a thread-local array from q_weight_1d, and that array is written to d_W as a
// single component. The weight templates receive no element index, so only the write location
// depends on the element.
//
// Element assignment: a z-layer starts at blockIdx.x * blockDim.z + threadIdx.z and then steps by
// gridDim.x * blockDim.z until num_elem is reached.
// Shared memory: each z-layer gets its own window into the dynamic `slice` array, offset by
// T_1D scalars per layer in 1D and T_1D * T_1D scalars per layer in 2D/3D.
// NOTE(review): the three stride arguments (1, size * num_elem, size) look like node, component
// and element strides -- confirm against cuda-shared-basis-read-write-templates.h.
extern "C" __global__ void Weight(const CeedInt num_elem, const CeedScalar *__restrict__ q_weight_1d, CeedScalar *__restrict__ d_W) {
  extern __shared__ CeedScalar slice[];

  // Values per element, i.e. BASIS_Q_1D^BASIS_DIM (for BASIS_DIM of 1, 2 or 3)
  const CeedInt q_elem_size = BASIS_Q_1D * (BASIS_DIM > 1 ? BASIS_Q_1D : 1) * (BASIS_DIM > 2 ? BASIS_Q_1D : 1);

  // Thread coordinates plus this z-layer's window into the shared scratch array
  SharedData_Cuda ctx;
  ctx.t_id_x = threadIdx.x;
  ctx.t_id_y = threadIdx.y;
  ctx.t_id_z = threadIdx.z;
  ctx.t_id   = threadIdx.x + blockDim.x * (threadIdx.y + threadIdx.z * blockDim.y);
  ctx.slice  = &slice[ctx.t_id_z * T_1D * (BASIS_DIM > 1 ? T_1D : 1)];

  // Thread-local weights; only the 3D path holds more than one value
  CeedScalar r_weight[BASIS_DIM > 2 ? BASIS_Q_1D : 1];

  // Kept unsigned so the loop increment uses the same mixed signed/unsigned arithmetic as before
  const unsigned int elem_step = gridDim.x * blockDim.z;

  for (CeedInt e = blockIdx.x * blockDim.z + threadIdx.z; e < num_elem; e += elem_step) {
    switch (BASIS_DIM) {
      case 1:
        Weight1d<BASIS_Q_1D>(ctx, q_weight_1d, r_weight);
        WriteElementStrided1d<1, BASIS_Q_1D>(ctx, e, 1, q_elem_size * num_elem, q_elem_size, r_weight, d_W);
        break;
      case 2:
        WeightTensor2d<BASIS_Q_1D>(ctx, q_weight_1d, r_weight);
        WriteElementStrided2d<1, BASIS_Q_1D>(ctx, e, 1, q_elem_size * num_elem, q_elem_size, r_weight, d_W);
        break;
      case 3:
        WeightTensor3d<BASIS_Q_1D>(ctx, q_weight_1d, r_weight);
        WriteElementStrided3d<1, BASIS_Q_1D>(ctx, e, 1, q_elem_size * num_elem, q_elem_size, r_weight, d_W);
        break;
    }
  }
}
253