xref: /libCEED/include/ceed/jit-source/cuda/cuda-ref-basis-tensor-at-points.h (revision 80c135a87dd608e39d180e7bb5c260aa9fcc10a1)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 /// @file
9 /// Internal header for CUDA tensor product basis with AtPoints evaluation
10 
11 #include <ceed.h>
12 
13 //------------------------------------------------------------------------------
14 // Chebyshev values
15 //------------------------------------------------------------------------------
template <int Q_1D>
inline __device__ void ChebyshevPolynomialsAtPoint(const CeedScalar x, CeedScalar *chebyshev_x) {
  // Evaluate the first Q_1D Chebyshev polynomials of the second kind at x:
  //   U_0(x) = 1, U_1(x) = 2 x, U_i(x) = 2 x U_{i-1}(x) - U_{i-2}(x).
  // chebyshev_x must have room for at least Q_1D values.
  chebyshev_x[0] = 1.0;
  // Compile-time guard: avoid writing past a 1-entry buffer when Q_1D == 1
  if (Q_1D > 1) chebyshev_x[1] = 2 * x;
  for (CeedInt i = 2; i < Q_1D; i++) chebyshev_x[i] = 2 * x * chebyshev_x[i - 1] - chebyshev_x[i - 2];
}
22 
template <int Q_1D>
inline __device__ void ChebyshevDerivativeAtPoint(const CeedScalar x, CeedScalar *chebyshev_dx) {
  // Evaluate derivatives of the first Q_1D Chebyshev polynomials (second kind) at x,
  // via the differentiated recurrence U_i' = 2 x U_{i-1}' + 2 U_{i-1} - U_{i-2}'.
  // chebyshev_x is a rolling window of the last three polynomial values:
  // U_n is kept at index (n + 1) % 3 (so U_0 -> slot 1, U_1 -> slot 2, U_2 -> slot 0, ...).
  CeedScalar chebyshev_x[3];

  chebyshev_x[1]  = 1.0;    // U_0
  chebyshev_x[2]  = 2 * x;  // U_1
  chebyshev_dx[0] = 0.0;
  // Compile-time guard: avoid writing past a 1-entry buffer when Q_1D == 1
  if (Q_1D > 1) chebyshev_dx[1] = 2.0;
  for (CeedInt i = 2; i < Q_1D; i++) {
    // U_i from U_{i-1} (slot i % 3) and U_{i-2} (slot (i + 2) % 3)
    chebyshev_x[(i + 1) % 3] = 2 * x * chebyshev_x[(i + 0) % 3] - chebyshev_x[(i + 2) % 3];
    chebyshev_dx[i]          = 2 * x * chebyshev_dx[i - 1] + 2 * chebyshev_x[(i + 0) % 3] - chebyshev_dx[i - 2];
  }
}
36 
37 //------------------------------------------------------------------------------
38 // Tensor Basis Kernels AtPoints
39 //------------------------------------------------------------------------------
40 
41 //------------------------------------------------------------------------------
42 // Interp
43 //------------------------------------------------------------------------------
extern "C" __global__ void InterpAtPoints(const CeedInt num_elem, const CeedInt is_transpose, const CeedScalar *__restrict__ chebyshev_interp_1d,
                                          const CeedScalar *__restrict__ coords, const CeedScalar *__restrict__ u, CeedScalar *__restrict__ v) {
  // Tensor-product interpolation between nodal values and arbitrary points, through an
  // intermediate Chebyshev coefficient representation held in shared memory.
  //   is_transpose == 0: nodal values u -> Chebyshev coeffs -> values at points v
  //   is_transpose != 0: point values u -> Chebyshev coeffs (atomic scatter) -> nodal values v
  // Launch layout: blocks grid-stride over elements (elem = blockIdx.x, += gridDim.x);
  // threads of a block cooperate on the shared-memory contractions and stride over points.
  // The BASIS_* and POINTS_BUFF_LEN macros are supplied at JIT compile time.
  const CeedInt i = threadIdx.x;

  // Shared memory layout: [1D interp matrix | contraction buffer 1 | contraction buffer 2 | Chebyshev coeffs]
  __shared__ CeedScalar s_mem[BASIS_Q_1D * BASIS_P_1D + 2 * BASIS_BUF_LEN + POINTS_BUFF_LEN * BASIS_Q_1D];
  CeedScalar           *s_chebyshev_interp_1d = s_mem;
  CeedScalar           *s_buffer_1            = s_mem + BASIS_Q_1D * BASIS_P_1D;
  CeedScalar           *s_buffer_2            = s_buffer_1 + BASIS_BUF_LEN;
  CeedScalar           *s_chebyshev_coeffs    = s_buffer_2 + BASIS_BUF_LEN;
  // Per-thread scratch: Chebyshev values at one coordinate, plus ping-pong contraction buffers
  CeedScalar            chebyshev_x[BASIS_Q_1D], buffer_1[POINTS_BUFF_LEN], buffer_2[POINTS_BUFF_LEN];
  // Cooperatively stage the 1D interpolation matrix in shared memory
  // (first read is preceded by a __syncthreads() in both branches below)
  for (CeedInt k = i; k < BASIS_Q_1D * BASIS_P_1D; k += blockDim.x) {
    s_chebyshev_interp_1d[k] = chebyshev_interp_1d[k];
  }

  const CeedInt P             = BASIS_P_1D;
  const CeedInt Q             = BASIS_Q_1D;
  // Node/point strides swap with the direction of application
  const CeedInt u_stride      = is_transpose ? BASIS_NUM_PTS : BASIS_NUM_NODES;
  const CeedInt v_stride      = is_transpose ? BASIS_NUM_NODES : BASIS_NUM_PTS;
  const CeedInt u_comp_stride = num_elem * (is_transpose ? BASIS_NUM_PTS : BASIS_NUM_NODES);
  const CeedInt v_comp_stride = num_elem * (is_transpose ? BASIS_NUM_NODES : BASIS_NUM_PTS);
  const CeedInt u_size        = is_transpose ? BASIS_NUM_PTS : BASIS_NUM_NODES;

  // Apply basis element by element
  if (is_transpose) {
    for (CeedInt elem = blockIdx.x; elem < num_elem; elem += gridDim.x) {
      for (CeedInt comp = 0; comp < BASIS_NUM_COMP; comp++) {
        const CeedScalar *cur_u = u + elem * u_stride + comp * u_comp_stride;
        CeedScalar       *cur_v = v + elem * v_stride + comp * v_comp_stride;
        CeedInt           pre   = 1;
        CeedInt           post  = 1;

        // Clear Chebyshev coeffs (accumulated below via atomicAdd)
        for (CeedInt k = i; k < BASIS_NUM_QPTS; k += blockDim.x) {
          s_chebyshev_coeffs[k] = 0.0;
        }

        // Map from point: each thread takes points in a strided fashion and scatters each
        // point's tensor-product Chebyshev contribution into the shared coefficient array
        __syncthreads();
        for (CeedInt p = threadIdx.x; p < BASIS_NUM_PTS; p += blockDim.x) {
          pre  = 1;
          post = 1;
          for (CeedInt d = 0; d < BASIS_DIM; d++) {
            // Update buffers used (pre stays 1 per point; no-op kept for pattern symmetry)
            pre /= 1;
            // Ping-pong between the per-thread buffers; last dimension targets shared coeffs
            const CeedScalar *in  = d == 0 ? (cur_u + p) : (d % 2 ? buffer_2 : buffer_1);
            CeedScalar       *out = d == BASIS_DIM - 1 ? s_chebyshev_coeffs : (d % 2 ? buffer_1 : buffer_2);

            // Build Chebyshev polynomial values at this point's d-th coordinate
            ChebyshevPolynomialsAtPoint<BASIS_Q_1D>(coords[elem * u_stride + d * u_comp_stride + p], chebyshev_x);

            // Contract along middle index
            for (CeedInt a = 0; a < pre; a++) {
              for (CeedInt c = 0; c < post; c++) {
                if (d == BASIS_DIM - 1) {
                  // (j + p) % Q is a permutation of 0..Q-1, so all Q coefficients are still
                  // updated; the per-point rotation staggers addresses across threads,
                  // presumably to reduce atomicAdd contention — TODO confirm intent
                  for (CeedInt j = 0; j < Q; j++) atomicAdd(&out[(a * Q + (j + p) % Q) * post + c], chebyshev_x[(j + p) % Q] * in[a * post + c]);
                } else {
                  for (CeedInt j = 0; j < Q; j++) out[(a * Q + j) * post + c] = chebyshev_x[j] * in[a * post + c];
                }
              }
            }
            post *= Q;
          }
        }

        // Map from coefficients: contract the shared Chebyshev coefficients against the
        // transposed 1D interpolation matrix, one dimension at a time, into nodal values
        pre  = BASIS_NUM_QPTS;
        post = 1;
        for (CeedInt d = 0; d < BASIS_DIM; d++) {
          // Barrier: the previous pass's writes must be visible before they are read here
          __syncthreads();
          // Update buffers used
          pre /= Q;
          const CeedScalar *in       = d == 0 ? s_chebyshev_coeffs : (d % 2 ? s_buffer_2 : s_buffer_1);
          CeedScalar       *out      = d == BASIS_DIM - 1 ? cur_v : (d % 2 ? s_buffer_1 : s_buffer_2);
          const CeedInt     writeLen = pre * post * P;

          // Contract along middle index; threads partition the output entries
          for (CeedInt k = i; k < writeLen; k += blockDim.x) {
            const CeedInt c   = k % post;
            const CeedInt j   = (k / post) % P;
            const CeedInt a   = k / (post * P);
            CeedScalar    v_k = 0;

            // Transposed application: interp matrix indexed [b * P + j]
            for (CeedInt b = 0; b < Q; b++) v_k += s_chebyshev_interp_1d[j + b * BASIS_P_1D] * in[(a * Q + b) * post + c];
            out[k] = v_k;
          }
          post *= P;
        }
      }
    }
  } else {
    for (CeedInt elem = blockIdx.x; elem < num_elem; elem += gridDim.x) {
      for (CeedInt comp = 0; comp < BASIS_NUM_COMP; comp++) {
        const CeedScalar *cur_u = u + elem * u_stride + comp * u_comp_stride;
        CeedScalar       *cur_v = v + elem * v_stride + comp * v_comp_stride;
        CeedInt           pre   = u_size;
        CeedInt           post  = 1;

        // Map to coefficients: contract nodal values against the 1D interpolation matrix,
        // one dimension at a time, into the shared Chebyshev coefficient array
        for (CeedInt d = 0; d < BASIS_DIM; d++) {
          // Barrier: the previous pass's writes must be visible before they are read here
          __syncthreads();
          // Update buffers used
          pre /= P;
          const CeedScalar *in       = d == 0 ? cur_u : (d % 2 ? s_buffer_2 : s_buffer_1);
          CeedScalar       *out      = d == BASIS_DIM - 1 ? s_chebyshev_coeffs : (d % 2 ? s_buffer_1 : s_buffer_2);
          const CeedInt     writeLen = pre * post * Q;

          // Contract along middle index; threads partition the output entries
          for (CeedInt k = i; k < writeLen; k += blockDim.x) {
            const CeedInt c   = k % post;
            const CeedInt j   = (k / post) % Q;
            const CeedInt a   = k / (post * Q);
            CeedScalar    v_k = 0;

            for (CeedInt b = 0; b < P; b++) v_k += s_chebyshev_interp_1d[j * BASIS_P_1D + b] * in[(a * P + b) * post + c];
            out[k] = v_k;
          }
          post *= Q;
        }

        // Map to point: each thread evaluates the Chebyshev expansion at its own points
        __syncthreads();
        for (CeedInt p = threadIdx.x; p < BASIS_NUM_PTS; p += blockDim.x) {
          pre  = BASIS_NUM_QPTS;
          post = 1;
          for (CeedInt d = 0; d < BASIS_DIM; d++) {
            // Update buffers used
            pre /= Q;
            // Ping-pong between the per-thread buffers; last dimension writes the point value
            const CeedScalar *in  = d == 0 ? s_chebyshev_coeffs : (d % 2 ? buffer_2 : buffer_1);
            CeedScalar       *out = d == BASIS_DIM - 1 ? (cur_v + p) : (d % 2 ? buffer_1 : buffer_2);

            // Build Chebyshev polynomial values at this point's d-th coordinate
            ChebyshevPolynomialsAtPoint<BASIS_Q_1D>(coords[elem * v_stride + d * v_comp_stride + p], chebyshev_x);

            // Contract along middle index
            for (CeedInt a = 0; a < pre; a++) {
              for (CeedInt c = 0; c < post; c++) {
                CeedScalar v_k = 0;

                for (CeedInt b = 0; b < Q; b++) v_k += chebyshev_x[b] * in[(a * Q + b) * post + c];
                out[a * post + c] = v_k;
              }
            }
            // post stays 1 for per-point evaluation; no-op kept for pattern symmetry
            post *= 1;
          }
        }
      }
    }
  }
}
193 
194 //------------------------------------------------------------------------------
195 // Grad
196 //------------------------------------------------------------------------------
extern "C" __global__ void GradAtPoints(const CeedInt num_elem, const CeedInt is_transpose, const CeedScalar *__restrict__ chebyshev_interp_1d,
                                        const CeedScalar *__restrict__ coords, const CeedScalar *__restrict__ u, CeedScalar *__restrict__ v) {
  // Tensor-product gradient evaluation at arbitrary points through an intermediate
  // Chebyshev coefficient representation. For each derivative direction dim_1 the 1D
  // Chebyshev derivative is applied in that dimension and plain Chebyshev values in the
  // others.
  //   is_transpose == 0: nodal values u -> Chebyshev coeffs -> gradient at points v
  //   is_transpose != 0: gradient at points u -> Chebyshev coeffs (atomic scatter) -> nodal values v
  // Launch layout: blocks grid-stride over elements (elem = blockIdx.x, += gridDim.x);
  // threads of a block cooperate on the shared-memory contractions and stride over points.
  // The BASIS_* and POINTS_BUFF_LEN macros are supplied at JIT compile time.
  const CeedInt i = threadIdx.x;

  // Shared memory layout: [1D interp matrix | contraction buffer 1 | contraction buffer 2 | Chebyshev coeffs]
  __shared__ CeedScalar s_mem[BASIS_Q_1D * BASIS_P_1D + 2 * BASIS_BUF_LEN + POINTS_BUFF_LEN * BASIS_Q_1D];
  CeedScalar           *s_chebyshev_interp_1d = s_mem;
  CeedScalar           *s_buffer_1            = s_mem + BASIS_Q_1D * BASIS_P_1D;
  CeedScalar           *s_buffer_2            = s_buffer_1 + BASIS_BUF_LEN;
  CeedScalar           *s_chebyshev_coeffs    = s_buffer_2 + BASIS_BUF_LEN;
  // Per-thread scratch: Chebyshev values/derivatives at one coordinate, plus ping-pong buffers
  CeedScalar            chebyshev_x[BASIS_Q_1D], buffer_1[POINTS_BUFF_LEN], buffer_2[POINTS_BUFF_LEN];
  // Cooperatively stage the 1D interpolation matrix in shared memory
  // (first read is preceded by a __syncthreads() in both branches below)
  for (CeedInt k = i; k < BASIS_Q_1D * BASIS_P_1D; k += blockDim.x) {
    s_chebyshev_interp_1d[k] = chebyshev_interp_1d[k];
  }

  const CeedInt P             = BASIS_P_1D;
  const CeedInt Q             = BASIS_Q_1D;
  // Node/point strides swap with the direction of application
  const CeedInt u_stride      = is_transpose ? BASIS_NUM_PTS : BASIS_NUM_NODES;
  const CeedInt v_stride      = is_transpose ? BASIS_NUM_NODES : BASIS_NUM_PTS;
  const CeedInt u_comp_stride = num_elem * (is_transpose ? BASIS_NUM_PTS : BASIS_NUM_NODES);
  const CeedInt v_comp_stride = num_elem * (is_transpose ? BASIS_NUM_NODES : BASIS_NUM_PTS);
  const CeedInt u_size        = is_transpose ? BASIS_NUM_PTS : BASIS_NUM_NODES;
  // The gradient dimension index lives on the point-valued side only
  const CeedInt u_dim_stride  = is_transpose ? num_elem * BASIS_NUM_PTS * BASIS_NUM_COMP : 0;
  const CeedInt v_dim_stride  = is_transpose ? 0 : num_elem * BASIS_NUM_PTS * BASIS_NUM_COMP;

  // Apply basis element by element
  if (is_transpose) {
    for (CeedInt elem = blockIdx.x; elem < num_elem; elem += gridDim.x) {
      for (CeedInt comp = 0; comp < BASIS_NUM_COMP; comp++) {
        CeedScalar *cur_v = v + elem * v_stride + comp * v_comp_stride;
        CeedInt     pre   = 1;
        CeedInt     post  = 1;

        // Clear Chebyshev coeffs (accumulated below via atomicAdd over all dimensions)
        for (CeedInt k = i; k < BASIS_NUM_QPTS; k += blockDim.x) {
          s_chebyshev_coeffs[k] = 0.0;
        }

        // Map from point: for every derivative direction dim_1, scatter each point's
        // tensor-product contribution into the shared coefficient array
        __syncthreads();
        for (CeedInt p = threadIdx.x; p < BASIS_NUM_PTS; p += blockDim.x) {
          for (CeedInt dim_1 = 0; dim_1 < BASIS_DIM; dim_1++) {
            const CeedScalar *cur_u = u + elem * u_stride + dim_1 * u_dim_stride + comp * u_comp_stride;

            pre  = 1;
            post = 1;
            for (CeedInt dim_2 = 0; dim_2 < BASIS_DIM; dim_2++) {
              // Update buffers used (pre stays 1 per point; no-op kept for pattern symmetry)
              pre /= 1;
              // Ping-pong between the per-thread buffers; last dimension targets shared coeffs
              const CeedScalar *in  = dim_2 == 0 ? (cur_u + p) : (dim_2 % 2 ? buffer_2 : buffer_1);
              CeedScalar       *out = dim_2 == BASIS_DIM - 1 ? s_chebyshev_coeffs : (dim_2 % 2 ? buffer_1 : buffer_2);

              // Build Chebyshev polynomial values: derivative in the gradient dimension,
              // plain values in all other dimensions
              if (dim_1 == dim_2) ChebyshevDerivativeAtPoint<BASIS_Q_1D>(coords[elem * u_stride + dim_2 * u_comp_stride + p], chebyshev_x);
              else ChebyshevPolynomialsAtPoint<BASIS_Q_1D>(coords[elem * u_stride + dim_2 * u_comp_stride + p], chebyshev_x);

              // Contract along middle index
              for (CeedInt a = 0; a < pre; a++) {
                for (CeedInt c = 0; c < post; c++) {
                  if (dim_2 == BASIS_DIM - 1) {
                    // (j + p) % Q is a permutation of 0..Q-1, so all Q coefficients are still
                    // updated; the per-point rotation staggers addresses across threads,
                    // presumably to reduce atomicAdd contention — TODO confirm intent
                    for (CeedInt j = 0; j < Q; j++) atomicAdd(&out[(a * Q + (j + p) % Q) * post + c], chebyshev_x[(j + p) % Q] * in[a * post + c]);
                  } else {
                    for (CeedInt j = 0; j < Q; j++) out[(a * Q + j) * post + c] = chebyshev_x[j] * in[a * post + c];
                  }
                }
              }
              post *= Q;
            }
          }
        }

        // Map from coefficients: contract the shared Chebyshev coefficients against the
        // transposed 1D interpolation matrix, one dimension at a time, into nodal values
        pre  = BASIS_NUM_QPTS;
        post = 1;
        for (CeedInt d = 0; d < BASIS_DIM; d++) {
          // Barrier: the previous pass's writes must be visible before they are read here
          __syncthreads();
          // Update buffers used
          pre /= Q;
          const CeedScalar *in       = d == 0 ? s_chebyshev_coeffs : (d % 2 ? s_buffer_2 : s_buffer_1);
          CeedScalar       *out      = d == BASIS_DIM - 1 ? cur_v : (d % 2 ? s_buffer_1 : s_buffer_2);
          const CeedInt     writeLen = pre * post * P;

          // Contract along middle index; threads partition the output entries
          for (CeedInt k = i; k < writeLen; k += blockDim.x) {
            const CeedInt c   = k % post;
            const CeedInt j   = (k / post) % P;
            const CeedInt a   = k / (post * P);
            CeedScalar    v_k = 0;

            // Transposed application: interp matrix indexed [b * P + j]
            for (CeedInt b = 0; b < Q; b++) v_k += s_chebyshev_interp_1d[j + b * BASIS_P_1D] * in[(a * Q + b) * post + c];
            out[k] = v_k;
          }
          post *= P;
        }
      }
    }
  } else {
    for (CeedInt elem = blockIdx.x; elem < num_elem; elem += gridDim.x) {
      for (CeedInt comp = 0; comp < BASIS_NUM_COMP; comp++) {
        const CeedScalar *cur_u = u + elem * u_stride + comp * u_comp_stride;
        CeedInt           pre   = u_size;
        CeedInt           post  = 1;

        // Map to coefficients: contract nodal values against the 1D interpolation matrix,
        // one dimension at a time, into the shared Chebyshev coefficient array
        for (CeedInt d = 0; d < BASIS_DIM; d++) {
          // Barrier: the previous pass's writes must be visible before they are read here
          __syncthreads();
          // Update buffers used
          pre /= P;
          const CeedScalar *in       = d == 0 ? cur_u : (d % 2 ? s_buffer_2 : s_buffer_1);
          CeedScalar       *out      = d == BASIS_DIM - 1 ? s_chebyshev_coeffs : (d % 2 ? s_buffer_1 : s_buffer_2);
          const CeedInt     writeLen = pre * post * Q;

          // Contract along middle index; threads partition the output entries
          for (CeedInt k = i; k < writeLen; k += blockDim.x) {
            const CeedInt c   = k % post;
            const CeedInt j   = (k / post) % Q;
            const CeedInt a   = k / (post * Q);
            CeedScalar    v_k = 0;

            for (CeedInt b = 0; b < P; b++) v_k += s_chebyshev_interp_1d[j * BASIS_P_1D + b] * in[(a * P + b) * post + c];
            out[k] = v_k;
          }
          post *= Q;
        }

        // Map to point: each thread evaluates, at its own points, one gradient component
        // per derivative direction dim_1
        __syncthreads();
        for (CeedInt p = threadIdx.x; p < BASIS_NUM_PTS; p += blockDim.x) {
          for (CeedInt dim_1 = 0; dim_1 < BASIS_DIM; dim_1++) {
            CeedScalar *cur_v = v + elem * v_stride + dim_1 * v_dim_stride + comp * v_comp_stride;

            pre  = BASIS_NUM_QPTS;
            post = 1;
            for (CeedInt dim_2 = 0; dim_2 < BASIS_DIM; dim_2++) {
              // Update buffers used
              pre /= Q;
              // Ping-pong between the per-thread buffers; last dimension writes the point value
              const CeedScalar *in  = dim_2 == 0 ? s_chebyshev_coeffs : (dim_2 % 2 ? buffer_2 : buffer_1);
              CeedScalar       *out = dim_2 == BASIS_DIM - 1 ? (cur_v + p) : (dim_2 % 2 ? buffer_1 : buffer_2);

              // Build Chebyshev polynomial values: derivative in the gradient dimension,
              // plain values in all other dimensions
              if (dim_1 == dim_2) ChebyshevDerivativeAtPoint<BASIS_Q_1D>(coords[elem * v_stride + dim_2 * v_comp_stride + p], chebyshev_x);
              else ChebyshevPolynomialsAtPoint<BASIS_Q_1D>(coords[elem * v_stride + dim_2 * v_comp_stride + p], chebyshev_x);

              // Contract along middle index
              for (CeedInt a = 0; a < pre; a++) {
                for (CeedInt c = 0; c < post; c++) {
                  CeedScalar v_k = 0;

                  for (CeedInt b = 0; b < Q; b++) v_k += chebyshev_x[b] * in[(a * Q + b) * post + c];
                  out[a * post + c] = v_k;
                }
              }
              // post stays 1 for per-point evaluation; no-op kept for pattern symmetry
              post *= 1;
            }
          }
        }
      }
    }
  }
}
356