xref: /libCEED/include/ceed/jit-source/cuda/cuda-shared-basis-tensor.h (revision 9e201c85545dd39529c090846df629a32c15659b)
1*9e201c85SYohann // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2*9e201c85SYohann // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3*9e201c85SYohann //
4*9e201c85SYohann // SPDX-License-Identifier: BSD-2-Clause
5*9e201c85SYohann //
6*9e201c85SYohann // This file is part of CEED:  http://github.com/ceed
7*9e201c85SYohann 
8*9e201c85SYohann /// @file
9*9e201c85SYohann /// Internal header for CUDA shared memory tensor product basis
10*9e201c85SYohann #ifndef _ceed_cuda_shared_basis_tensor_h
11*9e201c85SYohann #define _ceed_cuda_shared_basis_tensor_h
12*9e201c85SYohann 
13*9e201c85SYohann #include <ceed.h>
14*9e201c85SYohann #include "cuda-shared-basis-read-write-templates.h"
15*9e201c85SYohann #include "cuda-shared-basis-tensor-templates.h"
16*9e201c85SYohann 
17*9e201c85SYohann //------------------------------------------------------------------------------
18*9e201c85SYohann // Interp kernels by dim
19*9e201c85SYohann //------------------------------------------------------------------------------
// Interpolate element values from nodes to quadrature points.
//
// For each element handled by this thread, nodal values (BASIS_P_1D points per
// direction) are read from d_U, passed through Interp1d / InterpTensor2d /
// InterpTensor3d (selected by BASIS_DIM), and the resulting values
// (BASIS_Q_1D points per direction) are written to d_V.
//
// num_elem : number of elements to process
// c_B      : matrix forwarded to the interpolation templates
//            (presumably the 1D interpolation matrix - confirm)
// d_U      : input; read with stride arguments 1, P^dim * num_elem, P^dim
//            (P = BASIS_P_1D, dim = BASIS_DIM)
// d_V      : output; written with stride arguments 1, Q^dim * num_elem, Q^dim
//            (Q = BASIS_Q_1D)
//
// Launch layout, as implied by the indexing below: threadIdx.z selects which
// element of the block a thread works on. The dynamic shared buffer `slice` is
// split into blockDim.z chunks of T_1D (1D) or T_1D*T_1D (2D/3D) scalars.
// NOTE(review): blockDim.x/.y are presumably T_1D; they are set by the host
// launcher, which is not visible here.
// NOTE(review): z-slices whose elem reaches num_elem leave the loop while other
// slices of the block may keep going. If the templates use __syncthreads()
// (not visible here), this relies on exited threads not blocking the barrier.
extern "C" __global__ void Interp(const CeedInt num_elem,
                                  const CeedScalar *c_B,
                                  const CeedScalar *__restrict__ d_U,
                                  CeedScalar *__restrict__ d_V) {
  extern __shared__ CeedScalar slice[];

  // Thread coordinates plus this thread's per-element chunk of shared memory
  SharedData_Cuda data;
  data.t_id_z = threadIdx.z;
  data.t_id_y = threadIdx.y;
  data.t_id_x = threadIdx.x;
  data.t_id   = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
  const CeedInt slice_len = T_1D * (BASIS_DIM > 1 ? T_1D : 1);
  data.slice  = slice + data.t_id_z * slice_len;

  // Per-thread registers: BASIS_NUM_COMP values, times the 1D point count in 3D
  CeedScalar r_U[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)];
  CeedScalar r_V[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)];

  // Slice z of block b starts at element b*blockDim.z + z and advances by one grid's worth of slices
  for (CeedInt elem = blockIdx.x * blockDim.z + threadIdx.z; elem < num_elem; elem += gridDim.x * blockDim.z) {
    switch (BASIS_DIM) {
      case 1: {
        const CeedInt in_pts = BASIS_P_1D, out_pts = BASIS_Q_1D;
        ReadElementStrided1d<BASIS_NUM_COMP, BASIS_P_1D>(data, elem, 1, in_pts * num_elem, in_pts, d_U, r_U);
        Interp1d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, c_B, r_V);
        WriteElementStrided1d<BASIS_NUM_COMP, BASIS_Q_1D>(data, elem, 1, out_pts * num_elem, out_pts, r_V, d_V);
      } break;
      case 2: {
        const CeedInt in_pts = BASIS_P_1D * BASIS_P_1D, out_pts = BASIS_Q_1D * BASIS_Q_1D;
        ReadElementStrided2d<BASIS_NUM_COMP, BASIS_P_1D>(data, elem, 1, in_pts * num_elem, in_pts, d_U, r_U);
        InterpTensor2d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, c_B, r_V);
        WriteElementStrided2d<BASIS_NUM_COMP, BASIS_Q_1D>(data, elem, 1, out_pts * num_elem, out_pts, r_V, d_V);
      } break;
      case 3: {
        const CeedInt in_pts  = BASIS_P_1D * BASIS_P_1D * BASIS_P_1D;
        const CeedInt out_pts = BASIS_Q_1D * BASIS_Q_1D * BASIS_Q_1D;
        ReadElementStrided3d<BASIS_NUM_COMP, BASIS_P_1D>(data, elem, 1, in_pts * num_elem, in_pts, d_U, r_U);
        InterpTensor3d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, c_B, r_V);
        WriteElementStrided3d<BASIS_NUM_COMP, BASIS_Q_1D>(data, elem, 1, out_pts * num_elem, out_pts, r_V, d_V);
      } break;
    }
  }
}
52*9e201c85SYohann 
// Apply the transpose of the interpolation: quadrature-point values back to nodes.
//
// For each element handled by this thread, values at BASIS_Q_1D points per
// direction are read from d_U, passed through InterpTranspose1d /
// InterpTransposeTensor2d / InterpTransposeTensor3d (selected by BASIS_DIM),
// and the results at BASIS_P_1D points per direction are written to d_V.
//
// num_elem : number of elements to process
// c_B      : matrix forwarded to the transpose-interpolation templates
//            (presumably the 1D interpolation matrix - confirm)
// d_U      : input; read with stride arguments 1, Q^dim * num_elem, Q^dim
//            (Q = BASIS_Q_1D, dim = BASIS_DIM)
// d_V      : output; written with stride arguments 1, P^dim * num_elem, P^dim
//            (P = BASIS_P_1D)
//
// Launch layout, as implied by the indexing below: threadIdx.z selects which
// element of the block a thread works on. The dynamic shared buffer `slice` is
// split into blockDim.z chunks of T_1D (1D) or T_1D*T_1D (2D/3D) scalars.
// NOTE(review): blockDim.x/.y are presumably T_1D; they are set by the host
// launcher, which is not visible here.
// NOTE(review): z-slices whose elem reaches num_elem leave the loop while other
// slices of the block may keep going. If the templates use __syncthreads()
// (not visible here), this relies on exited threads not blocking the barrier.
extern "C" __global__ void InterpTranspose(const CeedInt num_elem,
                                           const CeedScalar *c_B,
                                           const CeedScalar *__restrict__ d_U,
                                           CeedScalar *__restrict__ d_V) {
  extern __shared__ CeedScalar slice[];

  // Thread coordinates plus this thread's per-element chunk of shared memory
  SharedData_Cuda data;
  data.t_id_z = threadIdx.z;
  data.t_id_y = threadIdx.y;
  data.t_id_x = threadIdx.x;
  data.t_id   = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
  const CeedInt slice_len = T_1D * (BASIS_DIM > 1 ? T_1D : 1);
  data.slice  = slice + data.t_id_z * slice_len;

  // Per-thread registers: BASIS_NUM_COMP values, times the 1D point count in 3D
  CeedScalar r_U[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)];
  CeedScalar r_V[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)];

  // Slice z of block b starts at element b*blockDim.z + z and advances by one grid's worth of slices
  for (CeedInt elem = blockIdx.x * blockDim.z + threadIdx.z; elem < num_elem; elem += gridDim.x * blockDim.z) {
    switch (BASIS_DIM) {
      case 1: {
        const CeedInt in_pts = BASIS_Q_1D, out_pts = BASIS_P_1D;
        ReadElementStrided1d<BASIS_NUM_COMP, BASIS_Q_1D>(data, elem, 1, in_pts * num_elem, in_pts, d_U, r_U);
        InterpTranspose1d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, c_B, r_V);
        WriteElementStrided1d<BASIS_NUM_COMP, BASIS_P_1D>(data, elem, 1, out_pts * num_elem, out_pts, r_V, d_V);
      } break;
      case 2: {
        const CeedInt in_pts = BASIS_Q_1D * BASIS_Q_1D, out_pts = BASIS_P_1D * BASIS_P_1D;
        ReadElementStrided2d<BASIS_NUM_COMP, BASIS_Q_1D>(data, elem, 1, in_pts * num_elem, in_pts, d_U, r_U);
        InterpTransposeTensor2d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, c_B, r_V);
        WriteElementStrided2d<BASIS_NUM_COMP, BASIS_P_1D>(data, elem, 1, out_pts * num_elem, out_pts, r_V, d_V);
      } break;
      case 3: {
        const CeedInt in_pts  = BASIS_Q_1D * BASIS_Q_1D * BASIS_Q_1D;
        const CeedInt out_pts = BASIS_P_1D * BASIS_P_1D * BASIS_P_1D;
        ReadElementStrided3d<BASIS_NUM_COMP, BASIS_Q_1D>(data, elem, 1, in_pts * num_elem, in_pts, d_U, r_U);
        InterpTransposeTensor3d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, c_B, r_V);
        WriteElementStrided3d<BASIS_NUM_COMP, BASIS_P_1D>(data, elem, 1, out_pts * num_elem, out_pts, r_V, d_V);
      } break;
    }
  }
}
85*9e201c85SYohann 
86*9e201c85SYohann //------------------------------------------------------------------------------
87*9e201c85SYohann // Grad kernels by dim
88*9e201c85SYohann //------------------------------------------------------------------------------
// Evaluate derivatives of element values at quadrature points.
//
// For each element handled by this thread, nodal values (BASIS_P_1D points per
// direction) are read from d_U and passed through Grad1d / GradTensor2d /
// GradTensor3d, or GradTensorCollocated3d when BASIS_HAS_COLLOCATED_GRAD is
// set. The results are written to d_V: BASIS_NUM_COMP*BASIS_DIM components in
// 2D/3D, and BASIS_NUM_COMP (the same count, since BASIS_DIM == 1) in 1D.
//
// num_elem : number of elements to process
// c_B, c_G : matrices forwarded to the gradient templates
//            (presumably 1D interpolation and derivative matrices - confirm)
// d_U      : input; read with stride arguments 1, P^dim * num_elem, P^dim
// d_V      : output; written with stride arguments 1, Q^dim * num_elem, Q^dim
//            (P = BASIS_P_1D, Q = BASIS_Q_1D, dim = BASIS_DIM)
//
// Launch layout, as implied by the indexing below: threadIdx.z selects which
// element of the block a thread works on. The dynamic shared buffer `slice` is
// split into blockDim.z chunks of T_1D (1D) or T_1D*T_1D (2D/3D) scalars.
// NOTE(review): blockDim.x/.y are presumably T_1D; they are set by the host
// launcher, which is not visible here.
// NOTE(review): z-slices whose elem reaches num_elem leave the loop while other
// slices of the block may keep going. If the templates use __syncthreads()
// (not visible here), this relies on exited threads not blocking the barrier.
extern "C" __global__ void Grad(const CeedInt num_elem,
                                const CeedScalar *c_B, const CeedScalar *c_G,
                                const CeedScalar *__restrict__ d_U,
                                CeedScalar *__restrict__ d_V) {
  extern __shared__ CeedScalar slice[];

  // Thread coordinates plus this thread's per-element chunk of shared memory
  SharedData_Cuda data;
  data.t_id_z = threadIdx.z;
  data.t_id_y = threadIdx.y;
  data.t_id_x = threadIdx.x;
  data.t_id   = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
  const CeedInt slice_len = T_1D * (BASIS_DIM > 1 ? T_1D : 1);
  data.slice  = slice + data.t_id_z * slice_len;

  // Per-thread registers; the output holds BASIS_DIM derivative values per component
  CeedScalar r_U[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)];
  CeedScalar r_V[BASIS_NUM_COMP * BASIS_DIM * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)];

  // Slice z of block b starts at element b*blockDim.z + z and advances by one grid's worth of slices
  for (CeedInt elem = blockIdx.x * blockDim.z + threadIdx.z; elem < num_elem; elem += gridDim.x * blockDim.z) {
    switch (BASIS_DIM) {
      case 1: {
        const CeedInt in_pts = BASIS_P_1D, out_pts = BASIS_Q_1D;
        ReadElementStrided1d<BASIS_NUM_COMP, BASIS_P_1D>(data, elem, 1, in_pts * num_elem, in_pts, d_U, r_U);
        Grad1d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, c_B, c_G, r_V);
        WriteElementStrided1d<BASIS_NUM_COMP, BASIS_Q_1D>(data, elem, 1, out_pts * num_elem, out_pts, r_V, d_V);
      } break;
      case 2: {
        const CeedInt in_pts = BASIS_P_1D * BASIS_P_1D, out_pts = BASIS_Q_1D * BASIS_Q_1D;
        ReadElementStrided2d<BASIS_NUM_COMP, BASIS_P_1D>(data, elem, 1, in_pts * num_elem, in_pts, d_U, r_U);
        GradTensor2d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, c_B, c_G, r_V);
        WriteElementStrided2d<BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D>(data, elem, 1, out_pts * num_elem, out_pts, r_V, d_V);
      } break;
      case 3: {
        const CeedInt in_pts  = BASIS_P_1D * BASIS_P_1D * BASIS_P_1D;
        const CeedInt out_pts = BASIS_Q_1D * BASIS_Q_1D * BASIS_Q_1D;
        ReadElementStrided3d<BASIS_NUM_COMP, BASIS_P_1D>(data, elem, 1, in_pts * num_elem, in_pts, d_U, r_U);
        if (BASIS_HAS_COLLOCATED_GRAD) {
          GradTensorCollocated3d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, c_B, c_G, r_V);
        } else {
          GradTensor3d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, c_B, c_G, r_V);
        }
        WriteElementStrided3d<BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D>(data, elem, 1, out_pts * num_elem, out_pts, r_V, d_V);
      } break;
    }
  }
}
122*9e201c85SYohann 
// Apply the transpose of the gradient: quadrature-point derivative data back to nodes.
//
// For each element handled by this thread, input at BASIS_Q_1D points per
// direction is read from d_U: BASIS_NUM_COMP*BASIS_DIM components in 2D/3D,
// and BASIS_NUM_COMP (the same count, since BASIS_DIM == 1) in 1D. It is
// passed through GradTranspose1d / GradTransposeTensor2d /
// GradTransposeTensor3d, or GradTransposeTensorCollocated3d when
// BASIS_HAS_COLLOCATED_GRAD is set, and BASIS_NUM_COMP components at
// BASIS_P_1D points per direction are written to d_V.
//
// num_elem : number of elements to process
// c_B, c_G : matrices forwarded to the transpose-gradient templates
//            (presumably 1D interpolation and derivative matrices - confirm)
// d_U      : input; read with stride arguments 1, Q^dim * num_elem, Q^dim
// d_V      : output; written with stride arguments 1, P^dim * num_elem, P^dim
//            (P = BASIS_P_1D, Q = BASIS_Q_1D, dim = BASIS_DIM)
//
// Launch layout, as implied by the indexing below: threadIdx.z selects which
// element of the block a thread works on. The dynamic shared buffer `slice` is
// split into blockDim.z chunks of T_1D (1D) or T_1D*T_1D (2D/3D) scalars.
// NOTE(review): blockDim.x/.y are presumably T_1D; they are set by the host
// launcher, which is not visible here.
// NOTE(review): z-slices whose elem reaches num_elem leave the loop while other
// slices of the block may keep going. If the templates use __syncthreads()
// (not visible here), this relies on exited threads not blocking the barrier.
extern "C" __global__ void GradTranspose(const CeedInt num_elem,
                                         const CeedScalar *c_B, const CeedScalar *c_G,
                                         const CeedScalar *__restrict__ d_U,
                                         CeedScalar *__restrict__ d_V) {
  extern __shared__ CeedScalar slice[];

  // Thread coordinates plus this thread's per-element chunk of shared memory
  SharedData_Cuda data;
  data.t_id_z = threadIdx.z;
  data.t_id_y = threadIdx.y;
  data.t_id_x = threadIdx.x;
  data.t_id   = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
  const CeedInt slice_len = T_1D * (BASIS_DIM > 1 ? T_1D : 1);
  data.slice  = slice + data.t_id_z * slice_len;

  // Per-thread registers; the input holds BASIS_DIM derivative values per component
  CeedScalar r_U[BASIS_NUM_COMP * BASIS_DIM * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)];
  CeedScalar r_V[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)];

  // Slice z of block b starts at element b*blockDim.z + z and advances by one grid's worth of slices
  for (CeedInt elem = blockIdx.x * blockDim.z + threadIdx.z; elem < num_elem; elem += gridDim.x * blockDim.z) {
    switch (BASIS_DIM) {
      case 1: {
        const CeedInt in_pts = BASIS_Q_1D, out_pts = BASIS_P_1D;
        ReadElementStrided1d<BASIS_NUM_COMP, BASIS_Q_1D>(data, elem, 1, in_pts * num_elem, in_pts, d_U, r_U);
        GradTranspose1d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, c_B, c_G, r_V);
        WriteElementStrided1d<BASIS_NUM_COMP, BASIS_P_1D>(data, elem, 1, out_pts * num_elem, out_pts, r_V, d_V);
      } break;
      case 2: {
        const CeedInt in_pts = BASIS_Q_1D * BASIS_Q_1D, out_pts = BASIS_P_1D * BASIS_P_1D;
        ReadElementStrided2d<BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D>(data, elem, 1, in_pts * num_elem, in_pts, d_U, r_U);
        GradTransposeTensor2d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, c_B, c_G, r_V);
        WriteElementStrided2d<BASIS_NUM_COMP, BASIS_P_1D>(data, elem, 1, out_pts * num_elem, out_pts, r_V, d_V);
      } break;
      case 3: {
        const CeedInt in_pts  = BASIS_Q_1D * BASIS_Q_1D * BASIS_Q_1D;
        const CeedInt out_pts = BASIS_P_1D * BASIS_P_1D * BASIS_P_1D;
        ReadElementStrided3d<BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D>(data, elem, 1, in_pts * num_elem, in_pts, d_U, r_U);
        if (BASIS_HAS_COLLOCATED_GRAD) {
          GradTransposeTensorCollocated3d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, c_B, c_G, r_V);
        } else {
          GradTransposeTensor3d<BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D>(data, r_U, c_B, c_G, r_V);
        }
        WriteElementStrided3d<BASIS_NUM_COMP, BASIS_P_1D>(data, elem, 1, out_pts * num_elem, out_pts, r_V, d_V);
      } break;
    }
  }
}
156*9e201c85SYohann 
157*9e201c85SYohann //------------------------------------------------------------------------------
158*9e201c85SYohann // Weight kernels by dim
159*9e201c85SYohann //------------------------------------------------------------------------------
// Write quadrature weights for every element.
//
// Weight1d / WeightTensor2d / WeightTensor3d (selected by BASIS_DIM) build this
// thread's weight value(s) from q_weight_1d. They are written to d_W as a single
// component with stride arguments 1, Q^dim * num_elem, Q^dim
// (Q = BASIS_Q_1D, dim = BASIS_DIM).
//
// num_elem    : number of elements to process
// q_weight_1d : 1D weights forwarded to the weight templates
// d_W         : output array
//
// Launch layout, as implied by the indexing below: threadIdx.z selects which
// element of the block a thread works on. The dynamic shared buffer `slice` is
// split into blockDim.z chunks of T_1D (1D) or T_1D*T_1D (2D/3D) scalars.
// NOTE(review): blockDim.x/.y are presumably T_1D; they are set by the host
// launcher, which is not visible here.
extern "C" __global__ void Weight(const CeedInt num_elem,
                                  const CeedScalar *__restrict__ q_weight_1d,
                                  CeedScalar *__restrict__ d_W) {
  extern __shared__ CeedScalar slice[];

  // Thread coordinates plus this thread's per-element chunk of shared memory
  SharedData_Cuda data;
  data.t_id_z = threadIdx.z;
  data.t_id_y = threadIdx.y;
  data.t_id_x = threadIdx.x;
  data.t_id   = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
  const CeedInt slice_len = T_1D * (BASIS_DIM > 1 ? T_1D : 1);
  data.slice  = slice + data.t_id_z * slice_len;

  // One weight per thread, or BASIS_Q_1D of them in 3D
  CeedScalar r_W[BASIS_DIM > 2 ? BASIS_Q_1D : 1];

  // Slice z of block b starts at element b*blockDim.z + z and advances by one grid's worth of slices
  for (CeedInt elem = blockIdx.x * blockDim.z + threadIdx.z; elem < num_elem; elem += gridDim.x * blockDim.z) {
    switch (BASIS_DIM) {
      case 1: {
        const CeedInt out_pts = BASIS_Q_1D;
        Weight1d<BASIS_Q_1D>(data, q_weight_1d, r_W);
        WriteElementStrided1d<1, BASIS_Q_1D>(data, elem, 1, out_pts * num_elem, out_pts, r_W, d_W);
      } break;
      case 2: {
        const CeedInt out_pts = BASIS_Q_1D * BASIS_Q_1D;
        WeightTensor2d<BASIS_Q_1D>(data, q_weight_1d, r_W);
        WriteElementStrided2d<1, BASIS_Q_1D>(data, elem, 1, out_pts * num_elem, out_pts, r_W, d_W);
      } break;
      case 3: {
        const CeedInt out_pts = BASIS_Q_1D * BASIS_Q_1D * BASIS_Q_1D;
        WeightTensor3d<BASIS_Q_1D>(data, q_weight_1d, r_W);
        WriteElementStrided3d<1, BASIS_Q_1D>(data, elem, 1, out_pts * num_elem, out_pts, r_W, d_W);
      } break;
    }
  }
}
187*9e201c85SYohann 
188*9e201c85SYohann //------------------------------------------------------------------------------
189*9e201c85SYohann 
190*9e201c85SYohann #endif
191