xref: /libCEED/include/ceed/jit-source/cuda/cuda-shared-basis-tensor.h (revision 81b0e02b1dff3ada0066e6e9e82451c10717e5c1)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 /// @file
9 /// Internal header for CUDA shared memory tensor product basis
10 
11 #include <ceed.h>
12 
13 #include "cuda-shared-basis-read-write-templates.h"
14 #include "cuda-shared-basis-tensor-templates.h"
15 
16 //------------------------------------------------------------------------------
17 // Interp kernel by dim
18 //------------------------------------------------------------------------------
// Interp: for every element, read BASIS_P_1D^BASIS_DIM values per component from d_U, apply the
// Interp1d / InterpTensor2d / InterpTensor3d routine selected by BASIS_DIM, and write
// BASIS_Q_1D^BASIS_DIM values per component to d_V.
//   num_elem - number of elements to process
//   c_B      - matrix forwarded unchanged to the Interp* routine
//   d_U      - input array, accessed through ReadElementStrided*
//   d_V      - output array, accessed through WriteElementStrided*
// Uses dynamically sized shared memory; each threadIdx.z layer works on its own element and its own slice.
extern "C" __global__ void Interp(const CeedInt num_elem, const CeedScalar *c_B, const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V) {
  extern __shared__ CeedScalar slice[];

  // Per-element, per-component entry counts on the input (P) and output (Q) side
  const CeedInt num_nodes = BASIS_P_1D * (BASIS_DIM > 1 ? BASIS_P_1D : 1) * (BASIS_DIM > 2 ? BASIS_P_1D : 1);
  const CeedInt num_qpts  = BASIS_Q_1D * (BASIS_DIM > 1 ? BASIS_Q_1D : 1) * (BASIS_DIM > 2 ? BASIS_Q_1D : 1);
  // Number of shared memory scalars set aside for each threadIdx.z layer
  const CeedInt slice_size = T_1D * (BASIS_DIM > 1 ? T_1D : 1);
  // Number of elements covered by the whole grid in one pass of the loop below
  const unsigned int num_elem_per_pass = gridDim.x * blockDim.z;

  SharedData_Cuda data;
  data.t_id_x = threadIdx.x;
  data.t_id_y = threadIdx.y;
  data.t_id_z = threadIdx.z;
  data.t_id   = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.y * blockDim.x;
  data.slice  = slice + data.t_id_z * slice_size;

  // Thread-local input/output buffers; only the 3D routines need a 1D pencil per component
  CeedScalar r_U[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)];
  CeedScalar r_V[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)];

  // Strides handed to the read/write helpers are (1, size * num_elem, size);
  // presumably node, component, and element strides - confirm against the read-write templates
  for (CeedInt e = blockIdx.x * blockDim.z + threadIdx.z; e < num_elem; e += num_elem_per_pass) {
    switch (BASIS_DIM) {
      case 1: {
        ReadElementStrided1d<BASIS_NUM_COMP, BASIS_P_1D>(data, e, 1, num_nodes * num_elem, num_nodes, d_U, r_U);
        Interp1d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, c_B, r_V);
        WriteElementStrided1d<BASIS_NUM_COMP, BASIS_Q_1D>(data, e, 1, num_qpts * num_elem, num_qpts, r_V, d_V);
        break;
      }
      case 2: {
        ReadElementStrided2d<BASIS_NUM_COMP, BASIS_P_1D>(data, e, 1, num_nodes * num_elem, num_nodes, d_U, r_U);
        InterpTensor2d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, c_B, r_V);
        WriteElementStrided2d<BASIS_NUM_COMP, BASIS_Q_1D>(data, e, 1, num_qpts * num_elem, num_qpts, r_V, d_V);
        break;
      }
      case 3: {
        ReadElementStrided3d<BASIS_NUM_COMP, BASIS_P_1D>(data, e, 1, num_nodes * num_elem, num_nodes, d_U, r_U);
        InterpTensor3d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, c_B, r_V);
        WriteElementStrided3d<BASIS_NUM_COMP, BASIS_Q_1D>(data, e, 1, num_qpts * num_elem, num_qpts, r_V, d_V);
        break;
      }
      default:
        // Other dimensions are not handled; nothing is read or written
        break;
    }
  }
}
50 
// InterpTranspose: for every element, read BASIS_Q_1D^BASIS_DIM values per component from d_U, apply the
// InterpTranspose1d / InterpTransposeTensor2d / InterpTransposeTensor3d routine selected by BASIS_DIM,
// and write BASIS_P_1D^BASIS_DIM values per component to d_V.
//   num_elem - number of elements to process
//   c_B      - matrix forwarded unchanged to the InterpTranspose* routine
//   d_U      - input array, accessed through ReadElementStrided*
//   d_V      - output array, accessed through WriteElementStrided*
// Uses dynamically sized shared memory; each threadIdx.z layer works on its own element and its own slice.
extern "C" __global__ void InterpTranspose(const CeedInt num_elem, const CeedScalar *c_B, const CeedScalar *__restrict__ d_U,
                                           CeedScalar *__restrict__ d_V) {
  extern __shared__ CeedScalar slice[];

  // Per-element, per-component entry counts on the output (P) and input (Q) side
  const CeedInt num_nodes = BASIS_P_1D * (BASIS_DIM > 1 ? BASIS_P_1D : 1) * (BASIS_DIM > 2 ? BASIS_P_1D : 1);
  const CeedInt num_qpts  = BASIS_Q_1D * (BASIS_DIM > 1 ? BASIS_Q_1D : 1) * (BASIS_DIM > 2 ? BASIS_Q_1D : 1);
  // Number of shared memory scalars set aside for each threadIdx.z layer
  const CeedInt slice_size = T_1D * (BASIS_DIM > 1 ? T_1D : 1);
  // Number of elements covered by the whole grid in one pass of the loop below
  const unsigned int num_elem_per_pass = gridDim.x * blockDim.z;

  SharedData_Cuda data;
  data.t_id_x = threadIdx.x;
  data.t_id_y = threadIdx.y;
  data.t_id_z = threadIdx.z;
  data.t_id   = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.y * blockDim.x;
  data.slice  = slice + data.t_id_z * slice_size;

  // Thread-local input/output buffers; only the 3D routines need a 1D pencil per component
  CeedScalar r_U[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)];
  CeedScalar r_V[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)];

  // Strides handed to the read/write helpers are (1, size * num_elem, size);
  // presumably node, component, and element strides - confirm against the read-write templates
  for (CeedInt e = blockIdx.x * blockDim.z + threadIdx.z; e < num_elem; e += num_elem_per_pass) {
    switch (BASIS_DIM) {
      case 1: {
        ReadElementStrided1d<BASIS_NUM_COMP, BASIS_Q_1D>(data, e, 1, num_qpts * num_elem, num_qpts, d_U, r_U);
        InterpTranspose1d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, c_B, r_V);
        WriteElementStrided1d<BASIS_NUM_COMP, BASIS_P_1D>(data, e, 1, num_nodes * num_elem, num_nodes, r_V, d_V);
        break;
      }
      case 2: {
        ReadElementStrided2d<BASIS_NUM_COMP, BASIS_Q_1D>(data, e, 1, num_qpts * num_elem, num_qpts, d_U, r_U);
        InterpTransposeTensor2d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, c_B, r_V);
        WriteElementStrided2d<BASIS_NUM_COMP, BASIS_P_1D>(data, e, 1, num_nodes * num_elem, num_nodes, r_V, d_V);
        break;
      }
      case 3: {
        ReadElementStrided3d<BASIS_NUM_COMP, BASIS_Q_1D>(data, e, 1, num_qpts * num_elem, num_qpts, d_U, r_U);
        InterpTransposeTensor3d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, c_B, r_V);
        WriteElementStrided3d<BASIS_NUM_COMP, BASIS_P_1D>(data, e, 1, num_nodes * num_elem, num_nodes, r_V, d_V);
        break;
      }
      default:
        // Other dimensions are not handled; nothing is read or written
        break;
    }
  }
}
83 
84 //------------------------------------------------------------------------------
85 // Grad kernel by dim
86 //------------------------------------------------------------------------------
// Grad: for every element, read BASIS_P_1D^BASIS_DIM values per component from d_U, apply the gradient
// routine selected by BASIS_DIM, and write BASIS_Q_1D^BASIS_DIM values to d_V for each of the
// BASIS_NUM_COMP * BASIS_DIM output components (BASIS_NUM_COMP in the 1D branch).
// In 3D, BASIS_HAS_COLLOCATED_GRAD selects GradTensorCollocated3d over GradTensor3d.
//   num_elem - number of elements to process
//   c_B, c_G - matrices forwarded unchanged to the Grad* routine
//   d_U      - input array, accessed through ReadElementStrided*
//   d_V      - output array, accessed through WriteElementStrided*
// Uses dynamically sized shared memory; each threadIdx.z layer works on its own element and its own slice.
extern "C" __global__ void Grad(const CeedInt num_elem, const CeedScalar *c_B, const CeedScalar *c_G, const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V) {
  extern __shared__ CeedScalar slice[];

  // Per-element, per-component entry counts on the input (P) and output (Q) side
  const CeedInt num_nodes = BASIS_P_1D * (BASIS_DIM > 1 ? BASIS_P_1D : 1) * (BASIS_DIM > 2 ? BASIS_P_1D : 1);
  const CeedInt num_qpts  = BASIS_Q_1D * (BASIS_DIM > 1 ? BASIS_Q_1D : 1) * (BASIS_DIM > 2 ? BASIS_Q_1D : 1);
  // Number of shared memory scalars set aside for each threadIdx.z layer
  const CeedInt slice_size = T_1D * (BASIS_DIM > 1 ? T_1D : 1);
  // Number of elements covered by the whole grid in one pass of the loop below
  const unsigned int num_elem_per_pass = gridDim.x * blockDim.z;

  SharedData_Cuda data;
  data.t_id_x = threadIdx.x;
  data.t_id_y = threadIdx.y;
  data.t_id_z = threadIdx.z;
  data.t_id   = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.y * blockDim.x;
  data.slice  = slice + data.t_id_z * slice_size;

  // Thread-local buffers; the output holds BASIS_DIM entries for every input component
  CeedScalar r_U[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)];
  CeedScalar r_V[BASIS_NUM_COMP * BASIS_DIM * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)];

  // Strides handed to the read/write helpers are (1, size * num_elem, size);
  // presumably node, component, and element strides - confirm against the read-write templates
  for (CeedInt e = blockIdx.x * blockDim.z + threadIdx.z; e < num_elem; e += num_elem_per_pass) {
    switch (BASIS_DIM) {
      case 1: {
        ReadElementStrided1d<BASIS_NUM_COMP, BASIS_P_1D>(data, e, 1, num_nodes * num_elem, num_nodes, d_U, r_U);
        Grad1d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, c_B, c_G, r_V);
        WriteElementStrided1d<BASIS_NUM_COMP, BASIS_Q_1D>(data, e, 1, num_qpts * num_elem, num_qpts, r_V, d_V);
        break;
      }
      case 2: {
        ReadElementStrided2d<BASIS_NUM_COMP, BASIS_P_1D>(data, e, 1, num_nodes * num_elem, num_nodes, d_U, r_U);
        GradTensor2d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, c_B, c_G, r_V);
        WriteElementStrided2d<BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D>(data, e, 1, num_qpts * num_elem, num_qpts, r_V, d_V);
        break;
      }
      case 3: {
        ReadElementStrided3d<BASIS_NUM_COMP, BASIS_P_1D>(data, e, 1, num_nodes * num_elem, num_nodes, d_U, r_U);
        if (BASIS_HAS_COLLOCATED_GRAD) {
          GradTensorCollocated3d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, c_B, c_G, r_V);
        } else {
          GradTensor3d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, c_B, c_G, r_V);
        }
        WriteElementStrided3d<BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D>(data, e, 1, num_qpts * num_elem, num_qpts, r_V, d_V);
        break;
      }
      default:
        // Other dimensions are not handled; nothing is read or written
        break;
    }
  }
}
121 
// GradTranspose: for every element, read BASIS_Q_1D^BASIS_DIM values from d_U for each of the
// BASIS_NUM_COMP * BASIS_DIM input components (BASIS_NUM_COMP in the 1D branch), apply the transpose
// gradient routine selected by BASIS_DIM, and write BASIS_P_1D^BASIS_DIM values per component to d_V.
// In 3D, BASIS_HAS_COLLOCATED_GRAD selects GradTransposeTensorCollocated3d over GradTransposeTensor3d.
//   num_elem - number of elements to process
//   c_B, c_G - matrices forwarded unchanged to the GradTranspose* routine
//   d_U      - input array, accessed through ReadElementStrided*
//   d_V      - output array, accessed through WriteElementStrided*
// Uses dynamically sized shared memory; each threadIdx.z layer works on its own element and its own slice.
extern "C" __global__ void GradTranspose(const CeedInt num_elem, const CeedScalar *c_B, const CeedScalar *c_G, const CeedScalar *__restrict__ d_U,
                                         CeedScalar *__restrict__ d_V) {
  extern __shared__ CeedScalar slice[];

  // Per-element, per-component entry counts on the output (P) and input (Q) side
  const CeedInt num_nodes = BASIS_P_1D * (BASIS_DIM > 1 ? BASIS_P_1D : 1) * (BASIS_DIM > 2 ? BASIS_P_1D : 1);
  const CeedInt num_qpts  = BASIS_Q_1D * (BASIS_DIM > 1 ? BASIS_Q_1D : 1) * (BASIS_DIM > 2 ? BASIS_Q_1D : 1);
  // Number of shared memory scalars set aside for each threadIdx.z layer
  const CeedInt slice_size = T_1D * (BASIS_DIM > 1 ? T_1D : 1);
  // Number of elements covered by the whole grid in one pass of the loop below
  const unsigned int num_elem_per_pass = gridDim.x * blockDim.z;

  SharedData_Cuda data;
  data.t_id_x = threadIdx.x;
  data.t_id_y = threadIdx.y;
  data.t_id_z = threadIdx.z;
  data.t_id   = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.y * blockDim.x;
  data.slice  = slice + data.t_id_z * slice_size;

  // Thread-local buffers; the input holds BASIS_DIM entries for every output component
  CeedScalar r_U[BASIS_NUM_COMP * BASIS_DIM * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)];
  CeedScalar r_V[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)];

  // Strides handed to the read/write helpers are (1, size * num_elem, size);
  // presumably node, component, and element strides - confirm against the read-write templates
  for (CeedInt e = blockIdx.x * blockDim.z + threadIdx.z; e < num_elem; e += num_elem_per_pass) {
    switch (BASIS_DIM) {
      case 1: {
        ReadElementStrided1d<BASIS_NUM_COMP, BASIS_Q_1D>(data, e, 1, num_qpts * num_elem, num_qpts, d_U, r_U);
        GradTranspose1d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, c_B, c_G, r_V);
        WriteElementStrided1d<BASIS_NUM_COMP, BASIS_P_1D>(data, e, 1, num_nodes * num_elem, num_nodes, r_V, d_V);
        break;
      }
      case 2: {
        ReadElementStrided2d<BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D>(data, e, 1, num_qpts * num_elem, num_qpts, d_U, r_U);
        GradTransposeTensor2d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, c_B, c_G, r_V);
        WriteElementStrided2d<BASIS_NUM_COMP, BASIS_P_1D>(data, e, 1, num_nodes * num_elem, num_nodes, r_V, d_V);
        break;
      }
      case 3: {
        ReadElementStrided3d<BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D>(data, e, 1, num_qpts * num_elem, num_qpts, d_U, r_U);
        if (BASIS_HAS_COLLOCATED_GRAD) {
          GradTransposeTensorCollocated3d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, c_B, c_G, r_V);
        } else {
          GradTransposeTensor3d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, c_B, c_G, r_V);
        }
        WriteElementStrided3d<BASIS_NUM_COMP, BASIS_P_1D>(data, e, 1, num_nodes * num_elem, num_nodes, r_V, d_V);
        break;
      }
      default:
        // Other dimensions are not handled; nothing is read or written
        break;
    }
  }
}
156 
157 //------------------------------------------------------------------------------
158 // Weight kernels by dim
159 //------------------------------------------------------------------------------
// Weight: for every element, fill a thread-local buffer via Weight1d / WeightTensor2d / WeightTensor3d
// (selected by BASIS_DIM) from q_weight_1d and write BASIS_Q_1D^BASIS_DIM values for that element to d_W.
//   num_elem    - number of elements to process
//   q_weight_1d - array forwarded unchanged to the Weight* routine
//   d_W         - output array, accessed through WriteElementStrided* with a single component
// Each threadIdx.z layer works on its own element and its own slice of the dynamically sized shared memory.
extern "C" __global__ void Weight(const CeedInt num_elem, const CeedScalar *__restrict__ q_weight_1d, CeedScalar *__restrict__ d_W) {
  extern __shared__ CeedScalar slice[];

  // Entry count written per element
  const CeedInt num_qpts = BASIS_Q_1D * (BASIS_DIM > 1 ? BASIS_Q_1D : 1) * (BASIS_DIM > 2 ? BASIS_Q_1D : 1);
  // Number of shared memory scalars set aside for each threadIdx.z layer
  const CeedInt slice_size = T_1D * (BASIS_DIM > 1 ? T_1D : 1);
  // Number of elements covered by the whole grid in one pass of the loop below
  const unsigned int num_elem_per_pass = gridDim.x * blockDim.z;

  SharedData_Cuda data;
  data.t_id_x = threadIdx.x;
  data.t_id_y = threadIdx.y;
  data.t_id_z = threadIdx.z;
  data.t_id   = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.y * blockDim.x;
  data.slice  = slice + data.t_id_z * slice_size;

  // Thread-local weights; only the 3D routine needs a 1D pencil
  CeedScalar r_W[BASIS_DIM > 2 ? BASIS_Q_1D : 1];

  // Strides handed to the write helpers are (1, size * num_elem, size);
  // presumably node, component, and element strides - confirm against the read-write templates
  for (CeedInt e = blockIdx.x * blockDim.z + threadIdx.z; e < num_elem; e += num_elem_per_pass) {
    switch (BASIS_DIM) {
      case 1: {
        Weight1d<BASIS_Q_1D>(data, q_weight_1d, r_W);
        WriteElementStrided1d<1, BASIS_Q_1D>(data, e, 1, num_qpts * num_elem, num_qpts, r_W, d_W);
        break;
      }
      case 2: {
        WeightTensor2d<BASIS_Q_1D>(data, q_weight_1d, r_W);
        WriteElementStrided2d<1, BASIS_Q_1D>(data, e, 1, num_qpts * num_elem, num_qpts, r_W, d_W);
        break;
      }
      case 3: {
        WeightTensor3d<BASIS_Q_1D>(data, q_weight_1d, r_W);
        WriteElementStrided3d<1, BASIS_Q_1D>(data, e, 1, num_qpts * num_elem, num_qpts, r_W, d_W);
        break;
      }
      default:
        // Other dimensions are not handled; nothing is written
        break;
    }
  }
}
186