xref: /libCEED/include/ceed/jit-source/cuda/cuda-ref-basis-tensor-at-points.h (revision 14950a8eea941c036fb81fbb2249468a1035cf45)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 /// @file
9 /// Internal header for CUDA tensor product basis with AtPoints evaluation
10 
11 #include <ceed.h>
12 
13 //------------------------------------------------------------------------------
14 // Chebyshev values
15 //------------------------------------------------------------------------------
template <int Q_1D>
inline __device__ void ChebyshevPolynomialsAtPoint(const CeedScalar x, CeedScalar *chebyshev_x) {
  // Evaluate the first Q_1D Chebyshev polynomials (second kind, U_i) at x:
  //   U_0(x) = 1, U_1(x) = 2 x, U_i(x) = 2 x U_{i-1}(x) - U_{i-2}(x).
  // chebyshev_x must have room for Q_1D values.
  chebyshev_x[0] = 1.0;
  // Guard the degree-1 entry so Q_1D == 1 does not write past the end of
  // chebyshev_x (the condition is resolved at compile time for each template
  // instantiation); previously this write was unconditional.
  if (Q_1D > 1) chebyshev_x[1] = 2 * x;
  for (CeedInt i = 2; i < Q_1D; i++) chebyshev_x[i] = 2 * x * chebyshev_x[i - 1] - chebyshev_x[i - 2];
}
22 
template <int Q_1D>
inline __device__ void ChebyshevDerivativeAtPoint(const CeedScalar x, CeedScalar *chebyshev_dx) {
  // Evaluate the derivatives of the first Q_1D Chebyshev polynomials at x,
  // using the same polynomial family as ChebyshevPolynomialsAtPoint
  // (P_0 = 1, P_1 = 2 x, P_i = 2 x P_{i-1} - P_{i-2}, i.e. second kind).
  // Differentiating the recurrence gives
  //   P_i' = 2 x P_{i-1}' + 2 P_{i-1} - P_{i-2}',
  // which is what the loop below computes. Only a rolling 3-term window of
  // polynomial *values* is kept, so no full value table is needed.
  // chebyshev_dx must have room for Q_1D values.
  // NOTE(review): chebyshev_dx[1] (and chebyshev_x[2]) are written
  // unconditionally, so Q_1D >= 2 is an implicit precondition — confirm
  // against callers before relying on Q_1D == 1.
  CeedScalar chebyshev_x[3];

  chebyshev_x[1]  = 1.0;    // P_0
  chebyshev_x[2]  = 2 * x;  // P_1
  chebyshev_dx[0] = 0.0;    // P_0' = 0
  chebyshev_dx[1] = 2.0;    // P_1' = 2
  for (CeedInt i = 2; i < Q_1D; i++) {
    // Shift the value window so that after the three assignments below
    // chebyshev_x = [P_{i-2}, P_{i-1}, P_i].
    chebyshev_x[0]  = chebyshev_x[1];
    chebyshev_x[1]  = chebyshev_x[2];
    chebyshev_x[2]  = 2 * x * chebyshev_x[1] - chebyshev_x[0];
    // chebyshev_x[1] holds P_{i-1}, matching the differentiated recurrence.
    chebyshev_dx[i] = 2 * x * chebyshev_dx[i - 1] + 2 * chebyshev_x[1] - chebyshev_dx[i - 2];
  }
}
38 
39 //------------------------------------------------------------------------------
40 // Tensor Basis Kernels AtPoints
41 //------------------------------------------------------------------------------
42 
43 //------------------------------------------------------------------------------
44 // Interp
45 //------------------------------------------------------------------------------
extern "C" __global__ void InterpAtPoints(const CeedInt num_elem, const CeedInt is_transpose, const CeedScalar *__restrict__ chebyshev_interp_1d,
                                          const CeedScalar *__restrict__ coords, const CeedScalar *__restrict__ u, CeedScalar *__restrict__ v) {
  // Interpolation between tensor-product nodal values and arbitrary in-element
  // points, routed through a Chebyshev coefficient representation:
  //   apply     (is_transpose == 0): nodes -> Chebyshev coeffs -> point values
  //   transpose (is_transpose != 0): point values -> Chebyshev coeffs -> nodes
  //
  // Launch shape: 1-D thread block; grid-stride loop over elements via
  // blockIdx.x, so any grid size is valid. All threads of a block cooperate on
  // one (element, component) pair at a time.
  //
  // chebyshev_interp_1d is the 1-D nodes <-> Chebyshev conversion matrix,
  // BASIS_Q_1D x BASIS_P_1D row-major. coords holds the point coordinates,
  // laid out like the points-side field with one entry per dimension.
  // BASIS_* and POINTS_BUFF_LEN are compile-time macros supplied by the JIT
  // driver that includes this file.
  const CeedInt i = threadIdx.x;

  // Shared memory layout: [1-D conversion matrix | ping buffer | pong buffer |
  // per-element Chebyshev coefficients]. The two buffers ping-pong between
  // passes of the per-dimension tensor contraction.
  __shared__ CeedScalar s_mem[BASIS_Q_1D * BASIS_P_1D + 2 * BASIS_BUF_LEN + POINTS_BUFF_LEN * BASIS_Q_1D];
  CeedScalar           *s_chebyshev_interp_1d = s_mem;
  CeedScalar           *s_buffer_1            = s_mem + BASIS_Q_1D * BASIS_P_1D;
  CeedScalar           *s_buffer_2            = s_buffer_1 + BASIS_BUF_LEN;
  CeedScalar           *s_chebyshev_coeffs    = s_buffer_2 + BASIS_BUF_LEN;
  // Per-thread registers: Chebyshev values at one point, plus private
  // ping-pong buffers used in the per-point contractions.
  CeedScalar            chebyshev_x[BASIS_Q_1D], buffer_1[POINTS_BUFF_LEN], buffer_2[POINTS_BUFF_LEN];
  // Cooperatively stage the conversion matrix in shared memory; the barriers
  // inside the contraction loops below fence this before first use.
  for (CeedInt k = i; k < BASIS_Q_1D * BASIS_P_1D; k += blockDim.x) {
    s_chebyshev_interp_1d[k] = chebyshev_interp_1d[k];
  }

  // Strides flip between the nodes-side and points-side layouts depending on
  // the direction of application.
  const CeedInt P             = BASIS_P_1D;
  const CeedInt Q             = BASIS_Q_1D;
  const CeedInt u_stride      = is_transpose ? BASIS_NUM_PTS : BASIS_NUM_NODES;
  const CeedInt v_stride      = is_transpose ? BASIS_NUM_NODES : BASIS_NUM_PTS;
  const CeedInt u_comp_stride = num_elem * (is_transpose ? BASIS_NUM_PTS : BASIS_NUM_NODES);
  const CeedInt v_comp_stride = num_elem * (is_transpose ? BASIS_NUM_NODES : BASIS_NUM_PTS);
  const CeedInt u_size        = is_transpose ? BASIS_NUM_PTS : BASIS_NUM_NODES;

  // Apply basis element by element
  if (is_transpose) {
    for (CeedInt elem = blockIdx.x; elem < num_elem; elem += gridDim.x) {
      for (CeedInt comp = 0; comp < BASIS_NUM_COMP; comp++) {
        const CeedScalar *cur_u = u + elem * u_stride + comp * u_comp_stride;
        CeedScalar       *cur_v = v + elem * v_stride + comp * v_comp_stride;
        CeedInt           pre   = 1;
        CeedInt           post  = 1;

        // Clear Chebyshev coeffs (they are accumulated into with atomicAdd
        // below, so they must start at zero for every component).
        // NOTE(review): when BASIS_DIM == 1 the previous component's final
        // contraction reads s_chebyshev_coeffs with no barrier before this
        // clear — looks like a potential race for that case; confirm upstream.
        for (CeedInt k = i; k < BASIS_NUM_QPTS; k += blockDim.x) {
          s_chebyshev_coeffs[k] = 0.0;
        }

        // Map from point: each thread owns a strided subset of points. The
        // scalar value at a point is expanded, one dimension at a time, into a
        // rank-BASIS_DIM tensor of Chebyshev values. Intermediate dimensions
        // live in private registers; the final dimension scatters into the
        // shared coefficient array with atomicAdd because different points
        // contribute to the same coefficients.
        __syncthreads();
        for (CeedInt p = threadIdx.x; p < BASIS_NUM_PTS; p += blockDim.x) {
          pre  = 1;
          post = 1;
          for (CeedInt d = 0; d < BASIS_DIM; d++) {
            // Update buffers used
            // (pre stays 1 on the points side; this division is a deliberate
            // no-op kept for symmetry with the nodal contractions below.)
            pre /= 1;
            const CeedScalar *in  = d == 0 ? (cur_u + p) : (d % 2 ? buffer_2 : buffer_1);
            CeedScalar       *out = d == BASIS_DIM - 1 ? s_chebyshev_coeffs : (d % 2 ? buffer_1 : buffer_2);

            // Build Chebyshev polynomial values
            ChebyshevPolynomialsAtPoint<BASIS_Q_1D>(coords[elem * u_stride + d * u_comp_stride + p], chebyshev_x);

            // Contract along middle index
            for (CeedInt a = 0; a < pre; a++) {
              for (CeedInt c = 0; c < post; c++) {
                if (d == BASIS_DIM - 1) {
                  for (CeedInt j = 0; j < Q; j++) atomicAdd(&out[(a * Q + j) * post + c], chebyshev_x[j] * in[a * post + c]);
                } else {
                  for (CeedInt j = 0; j < Q; j++) out[(a * Q + j) * post + c] = chebyshev_x[j] * in[a * post + c];
                }
              }
            }
            post *= Q;
          }
        }

        // Map from coefficients: block-cooperative tensor contraction from
        // Chebyshev coefficients back to nodal values, one dimension per pass.
        // The barrier at the top of each pass fences the previous pass's (and
        // the point loop's) shared-memory writes.
        pre  = BASIS_NUM_QPTS;
        post = 1;
        for (CeedInt d = 0; d < BASIS_DIM; d++) {
          __syncthreads();
          // Update buffers used
          pre /= Q;
          const CeedScalar *in       = d == 0 ? s_chebyshev_coeffs : (d % 2 ? s_buffer_2 : s_buffer_1);
          CeedScalar       *out      = d == BASIS_DIM - 1 ? cur_v : (d % 2 ? s_buffer_1 : s_buffer_2);
          const CeedInt     writeLen = pre * post * P;

          // Contract along middle index; each thread computes a disjoint
          // subset of outputs, so no atomics are needed here. The index
          // j + b * BASIS_P_1D walks the conversion matrix column-wise, i.e.
          // applies its transpose (coeffs -> nodes).
          for (CeedInt k = i; k < writeLen; k += blockDim.x) {
            const CeedInt c   = k % post;
            const CeedInt j   = (k / post) % P;
            const CeedInt a   = k / (post * P);
            CeedScalar    v_k = 0;

            for (CeedInt b = 0; b < Q; b++) v_k += s_chebyshev_interp_1d[j + b * BASIS_P_1D] * in[(a * Q + b) * post + c];
            out[k] = v_k;
          }
          post *= P;
        }
      }
    }
  } else {
    for (CeedInt elem = blockIdx.x; elem < num_elem; elem += gridDim.x) {
      for (CeedInt comp = 0; comp < BASIS_NUM_COMP; comp++) {
        const CeedScalar *cur_u = u + elem * u_stride + comp * u_comp_stride;
        CeedScalar       *cur_v = v + elem * v_stride + comp * v_comp_stride;
        CeedInt           pre   = u_size;
        CeedInt           post  = 1;

        // Map to coefficients: block-cooperative contraction from nodal
        // values to Chebyshev coefficients, one dimension per pass, using the
        // shared ping-pong buffers. The conversion matrix is read row-major
        // (j * BASIS_P_1D + b), i.e. applied untransposed.
        for (CeedInt d = 0; d < BASIS_DIM; d++) {
          __syncthreads();
          // Update buffers used
          pre /= P;
          const CeedScalar *in       = d == 0 ? cur_u : (d % 2 ? s_buffer_2 : s_buffer_1);
          CeedScalar       *out      = d == BASIS_DIM - 1 ? s_chebyshev_coeffs : (d % 2 ? s_buffer_1 : s_buffer_2);
          const CeedInt     writeLen = pre * post * Q;

          // Contract along middle index
          for (CeedInt k = i; k < writeLen; k += blockDim.x) {
            const CeedInt c   = k % post;
            const CeedInt j   = (k / post) % Q;
            const CeedInt a   = k / (post * Q);
            CeedScalar    v_k = 0;

            for (CeedInt b = 0; b < P; b++) v_k += s_chebyshev_interp_1d[j * BASIS_P_1D + b] * in[(a * P + b) * post + c];
            out[k] = v_k;
          }
          post *= Q;
        }

        // Map to point: each thread owns a strided subset of points and
        // collapses the coefficient tensor one dimension at a time against
        // the Chebyshev values at that point, using private registers.
        __syncthreads();
        for (CeedInt p = threadIdx.x; p < BASIS_NUM_PTS; p += blockDim.x) {
          pre  = BASIS_NUM_QPTS;
          post = 1;
          for (CeedInt d = 0; d < BASIS_DIM; d++) {
            // Update buffers used
            pre /= Q;
            const CeedScalar *in  = d == 0 ? s_chebyshev_coeffs : (d % 2 ? buffer_2 : buffer_1);
            CeedScalar       *out = d == BASIS_DIM - 1 ? (cur_v + p) : (d % 2 ? buffer_1 : buffer_2);

            // Build Chebyshev polynomial values
            ChebyshevPolynomialsAtPoint<BASIS_Q_1D>(coords[elem * v_stride + d * v_comp_stride + p], chebyshev_x);

            // Contract along middle index
            for (CeedInt a = 0; a < pre; a++) {
              for (CeedInt c = 0; c < post; c++) {
                CeedScalar v_k = 0;

                for (CeedInt b = 0; b < Q; b++) v_k += chebyshev_x[b] * in[(a * Q + b) * post + c];
                out[a * post + c] = v_k;
              }
            }
            // post stays 1 at a point (each pass fully collapses one index);
            // this multiply is a deliberate no-op kept for symmetry.
            post *= 1;
          }
        }
      }
    }
  }
}
195 
196 //------------------------------------------------------------------------------
197 // Grad
198 //------------------------------------------------------------------------------
extern "C" __global__ void GradAtPoints(const CeedInt num_elem, const CeedInt is_transpose, const CeedScalar *__restrict__ chebyshev_interp_1d,
                                        const CeedScalar *__restrict__ coords, const CeedScalar *__restrict__ u, CeedScalar *__restrict__ v) {
  // Gradient evaluation between tensor-product nodal values and arbitrary
  // in-element points, via a Chebyshev coefficient representation:
  //   apply     (is_transpose == 0): nodes -> coeffs -> d/dx_i values at points
  //   transpose (is_transpose != 0): gradient values at points -> coeffs -> nodes
  // For derivative direction dim_1, the contraction along axis dim_2 uses
  // Chebyshev *derivative* values when dim_2 == dim_1 and plain Chebyshev
  // values otherwise.
  //
  // Launch shape: 1-D thread block; grid-stride loop over elements via
  // blockIdx.x. All threads of a block cooperate on one (element, component)
  // pair at a time. BASIS_* / POINTS_BUFF_LEN are JIT-supplied macros;
  // chebyshev_interp_1d is BASIS_Q_1D x BASIS_P_1D row-major.
  const CeedInt i = threadIdx.x;

  // Shared memory layout: [1-D conversion matrix | ping buffer | pong buffer |
  // per-element Chebyshev coefficients].
  __shared__ CeedScalar s_mem[BASIS_Q_1D * BASIS_P_1D + 2 * BASIS_BUF_LEN + POINTS_BUFF_LEN * BASIS_Q_1D];
  CeedScalar           *s_chebyshev_interp_1d = s_mem;
  CeedScalar           *s_buffer_1            = s_mem + BASIS_Q_1D * BASIS_P_1D;
  CeedScalar           *s_buffer_2            = s_buffer_1 + BASIS_BUF_LEN;
  CeedScalar           *s_chebyshev_coeffs    = s_buffer_2 + BASIS_BUF_LEN;
  // Per-thread registers: Chebyshev (or Chebyshev-derivative) values at one
  // point, plus private ping-pong buffers for the per-point contractions.
  CeedScalar            chebyshev_x[BASIS_Q_1D], buffer_1[POINTS_BUFF_LEN], buffer_2[POINTS_BUFF_LEN];
  // Cooperatively stage the conversion matrix; fenced by the barriers below.
  for (CeedInt k = i; k < BASIS_Q_1D * BASIS_P_1D; k += blockDim.x) {
    s_chebyshev_interp_1d[k] = chebyshev_interp_1d[k];
  }

  // Strides flip between the nodes-side and points-side layouts. The gradient
  // carries BASIS_DIM components on the points side only, so the dim stride is
  // zero on whichever side holds nodal values.
  const CeedInt P             = BASIS_P_1D;
  const CeedInt Q             = BASIS_Q_1D;
  const CeedInt u_stride      = is_transpose ? BASIS_NUM_PTS : BASIS_NUM_NODES;
  const CeedInt v_stride      = is_transpose ? BASIS_NUM_NODES : BASIS_NUM_PTS;
  const CeedInt u_comp_stride = num_elem * (is_transpose ? BASIS_NUM_PTS : BASIS_NUM_NODES);
  const CeedInt v_comp_stride = num_elem * (is_transpose ? BASIS_NUM_NODES : BASIS_NUM_PTS);
  const CeedInt u_size        = is_transpose ? BASIS_NUM_PTS : BASIS_NUM_NODES;
  const CeedInt u_dim_stride  = is_transpose ? num_elem * BASIS_NUM_PTS * BASIS_NUM_COMP : 0;
  const CeedInt v_dim_stride  = is_transpose ? 0 : num_elem * BASIS_NUM_PTS * BASIS_NUM_COMP;

  // Apply basis element by element
  if (is_transpose) {
    for (CeedInt elem = blockIdx.x; elem < num_elem; elem += gridDim.x) {
      for (CeedInt comp = 0; comp < BASIS_NUM_COMP; comp++) {
        CeedScalar *cur_v = v + elem * v_stride + comp * v_comp_stride;
        CeedInt     pre   = 1;
        CeedInt     post  = 1;

        // Clear Chebyshev coeffs — cleared once per component; all derivative
        // directions (dim_1) and all points then accumulate into the same
        // coefficient array with atomicAdd, which performs the sum over
        // gradient components required by the transpose.
        for (CeedInt k = i; k < BASIS_NUM_QPTS; k += blockDim.x) {
          s_chebyshev_coeffs[k] = 0.0;
        }

        // Map from point: each thread owns a strided subset of points. For
        // each derivative direction the point value is expanded into a
        // rank-BASIS_DIM tensor; intermediate dimensions use private
        // registers and the final dimension scatters into shared memory.
        __syncthreads();
        for (CeedInt p = threadIdx.x; p < BASIS_NUM_PTS; p += blockDim.x) {
          for (CeedInt dim_1 = 0; dim_1 < BASIS_DIM; dim_1++) {
            const CeedScalar *cur_u = u + elem * u_stride + dim_1 * u_dim_stride + comp * u_comp_stride;

            pre  = 1;
            post = 1;
            for (CeedInt dim_2 = 0; dim_2 < BASIS_DIM; dim_2++) {
              // Update buffers used
              // (pre stays 1 on the points side; deliberate no-op for
              // symmetry with the nodal contractions.)
              pre /= 1;
              const CeedScalar *in  = dim_2 == 0 ? (cur_u + p) : (dim_2 % 2 ? buffer_2 : buffer_1);
              CeedScalar       *out = dim_2 == BASIS_DIM - 1 ? s_chebyshev_coeffs : (dim_2 % 2 ? buffer_1 : buffer_2);

              // Build Chebyshev polynomial values — derivative values along
              // the axis matching the current gradient direction.
              if (dim_1 == dim_2) ChebyshevDerivativeAtPoint<BASIS_Q_1D>(coords[elem * u_stride + dim_2 * u_comp_stride + p], chebyshev_x);
              else ChebyshevPolynomialsAtPoint<BASIS_Q_1D>(coords[elem * u_stride + dim_2 * u_comp_stride + p], chebyshev_x);

              // Contract along middle index
              for (CeedInt a = 0; a < pre; a++) {
                for (CeedInt c = 0; c < post; c++) {
                  if (dim_2 == BASIS_DIM - 1) {
                    for (CeedInt j = 0; j < Q; j++) atomicAdd(&out[(a * Q + j) * post + c], chebyshev_x[j] * in[a * post + c]);
                  } else {
                    for (CeedInt j = 0; j < Q; j++) out[(a * Q + j) * post + c] = chebyshev_x[j] * in[a * post + c];
                  }
                }
              }
              post *= Q;
            }
          }
        }

        // Map from coefficients: block-cooperative contraction back to nodal
        // values, one dimension per pass; the barrier at the top of each pass
        // fences the previous pass's shared-memory writes. Column-wise access
        // (j + b * BASIS_P_1D) applies the transpose of the conversion matrix.
        pre  = BASIS_NUM_QPTS;
        post = 1;
        for (CeedInt d = 0; d < BASIS_DIM; d++) {
          __syncthreads();
          // Update buffers used
          pre /= Q;
          const CeedScalar *in       = d == 0 ? s_chebyshev_coeffs : (d % 2 ? s_buffer_2 : s_buffer_1);
          CeedScalar       *out      = d == BASIS_DIM - 1 ? cur_v : (d % 2 ? s_buffer_1 : s_buffer_2);
          const CeedInt     writeLen = pre * post * P;

          // Contract along middle index; disjoint outputs per thread, so no
          // atomics are needed here.
          for (CeedInt k = i; k < writeLen; k += blockDim.x) {
            const CeedInt c   = k % post;
            const CeedInt j   = (k / post) % P;
            const CeedInt a   = k / (post * P);
            CeedScalar    v_k = 0;

            for (CeedInt b = 0; b < Q; b++) v_k += s_chebyshev_interp_1d[j + b * BASIS_P_1D] * in[(a * Q + b) * post + c];
            out[k] = v_k;
          }
          post *= P;
        }
      }
    }
  } else {
    for (CeedInt elem = blockIdx.x; elem < num_elem; elem += gridDim.x) {
      for (CeedInt comp = 0; comp < BASIS_NUM_COMP; comp++) {
        const CeedScalar *cur_u = u + elem * u_stride + comp * u_comp_stride;
        CeedInt           pre   = u_size;
        CeedInt           post  = 1;

        // Map to coefficients: block-cooperative contraction from nodal
        // values to Chebyshev coefficients (computed once per component and
        // reused for every derivative direction below). Row-major access
        // applies the conversion matrix untransposed.
        for (CeedInt d = 0; d < BASIS_DIM; d++) {
          __syncthreads();
          // Update buffers used
          pre /= P;
          const CeedScalar *in       = d == 0 ? cur_u : (d % 2 ? s_buffer_2 : s_buffer_1);
          CeedScalar       *out      = d == BASIS_DIM - 1 ? s_chebyshev_coeffs : (d % 2 ? s_buffer_1 : s_buffer_2);
          const CeedInt     writeLen = pre * post * Q;

          // Contract along middle index
          for (CeedInt k = i; k < writeLen; k += blockDim.x) {
            const CeedInt c   = k % post;
            const CeedInt j   = (k / post) % Q;
            const CeedInt a   = k / (post * Q);
            CeedScalar    v_k = 0;

            for (CeedInt b = 0; b < P; b++) v_k += s_chebyshev_interp_1d[j * BASIS_P_1D + b] * in[(a * P + b) * post + c];
            out[k] = v_k;
          }
          post *= Q;
        }

        // Map to point: each thread owns a strided subset of points; for each
        // derivative direction it collapses the coefficient tensor one axis
        // at a time, substituting derivative values on the matching axis.
        __syncthreads();
        for (CeedInt p = threadIdx.x; p < BASIS_NUM_PTS; p += blockDim.x) {
          for (CeedInt dim_1 = 0; dim_1 < BASIS_DIM; dim_1++) {
            CeedScalar *cur_v = v + elem * v_stride + dim_1 * v_dim_stride + comp * v_comp_stride;

            pre  = BASIS_NUM_QPTS;
            post = 1;
            for (CeedInt dim_2 = 0; dim_2 < BASIS_DIM; dim_2++) {
              // Update buffers used
              pre /= Q;
              const CeedScalar *in  = dim_2 == 0 ? s_chebyshev_coeffs : (dim_2 % 2 ? buffer_2 : buffer_1);
              CeedScalar       *out = dim_2 == BASIS_DIM - 1 ? (cur_v + p) : (dim_2 % 2 ? buffer_1 : buffer_2);

              // Build Chebyshev polynomial values — derivative values along
              // the axis matching the current gradient direction.
              if (dim_1 == dim_2) ChebyshevDerivativeAtPoint<BASIS_Q_1D>(coords[elem * v_stride + dim_2 * v_comp_stride + p], chebyshev_x);
              else ChebyshevPolynomialsAtPoint<BASIS_Q_1D>(coords[elem * v_stride + dim_2 * v_comp_stride + p], chebyshev_x);

              // Contract along middle index
              for (CeedInt a = 0; a < pre; a++) {
                for (CeedInt c = 0; c < post; c++) {
                  CeedScalar v_k = 0;

                  for (CeedInt b = 0; b < Q; b++) v_k += chebyshev_x[b] * in[(a * Q + b) * post + c];
                  out[a * post + c] = v_k;
                }
              }
              // post stays 1 at a point (each pass fully collapses one
              // index); deliberate no-op kept for symmetry.
              post *= 1;
            }
          }
        }
      }
    }
  }
}
358