xref: /libCEED/include/ceed/jit-source/hip/hip-ref-basis-tensor.h (revision 381e65939e85104561074440c4dd3dd99bd0efff)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #include <ceed/ceed.h>
9 //------------------------------------------------------------------------------
10 // Tensor Basis Kernels
11 //------------------------------------------------------------------------------
12 
13 //------------------------------------------------------------------------------
14 // Interp
15 //------------------------------------------------------------------------------
// Apply the tensor-product interpolation operator, one element per block
// iteration and one component at a time.
//
// Parameters:
//   num_elem   number of elements to process
//   transpose  zero: contract with interp_1d[q * BASIS_P_1D + p] over p
//              (BASIS_P_1D inputs -> BASIS_Q_1D outputs per dimension);
//              nonzero: contract with the same entries over q instead
//              (BASIS_Q_1D inputs -> BASIS_P_1D outputs per dimension)
//   interp_1d  1D operator with BASIS_Q_1D * BASIS_P_1D entries
//   u          input, indexed as [comp][elem][point]
//   v          output, indexed as [comp][elem][point]; overwritten
//
// Launch layout: only the x components of the thread and block indices are
// used. Blocks stride over elements by gridDim.x, threads stride over entries
// by blockDim.x.
//
// Shared memory: a static array holding the 1D operator plus two scratch
// buffers of BASIS_BUF_LEN entries each.
// NOTE(review): BASIS_BUF_LEN is assumed to be large enough for the input of
// one element/component and for every intermediate contraction result; it is
// defined outside this file - confirm against the JIT definitions.
extern "C" __global__ void Interp(const CeedInt num_elem, const CeedInt transpose,
                                  const CeedScalar *__restrict__ interp_1d,
                                  const CeedScalar *__restrict__ u,
                                  CeedScalar *__restrict__ v) {
  const CeedInt i = threadIdx.x;

  __shared__ CeedScalar s_mem[BASIS_Q_1D * BASIS_P_1D + 2 * BASIS_BUF_LEN];
  CeedScalar *s_interp_1d = s_mem;
  CeedScalar *s_buffer_1 = s_mem + BASIS_Q_1D * BASIS_P_1D;
  CeedScalar *s_buffer_2 = s_buffer_1 + BASIS_BUF_LEN;
  // Stage the 1D operator in shared memory; the first barrier inside the
  // dimension loop below makes it visible to every thread before use.
  for (CeedInt k = i; k < BASIS_Q_1D * BASIS_P_1D; k += blockDim.x) {
    s_interp_1d[k] = interp_1d[k];
  }

  // P is the contracted (input) 1D size, Q the produced (output) 1D size
  const CeedInt P = transpose ? BASIS_Q_1D : BASIS_P_1D;
  const CeedInt Q = transpose ? BASIS_P_1D : BASIS_Q_1D;
  // Strides into the 1D operator for the output index j and input index b
  const CeedInt stride0 = transpose ? 1 : BASIS_P_1D;
  const CeedInt stride1 = transpose ? BASIS_P_1D : 1;
  const CeedInt u_stride = transpose ? BASIS_NUM_QPTS : BASIS_NUM_NODES;
  const CeedInt v_stride = transpose ? BASIS_NUM_NODES : BASIS_NUM_QPTS;
  const CeedInt u_comp_stride = num_elem * (transpose ? BASIS_NUM_QPTS : BASIS_NUM_NODES);
  const CeedInt v_comp_stride = num_elem * (transpose ? BASIS_NUM_NODES : BASIS_NUM_QPTS);
  const CeedInt u_size = transpose ? BASIS_NUM_QPTS : BASIS_NUM_NODES;

  // Apply basis element by element
  for (CeedInt elem = blockIdx.x; elem < num_elem; elem += gridDim.x) {
    for (CeedInt comp = 0; comp < BASIS_NUM_COMP; comp++) {
      const CeedScalar *cur_u = u + elem * u_stride + comp * u_comp_stride;
      CeedScalar *cur_v = v + elem * v_stride + comp * v_comp_stride;

      // When BASIS_DIM is odd, the final contraction of the previous
      // component/element reads s_buffer_1. Wait for every thread to finish
      // those reads before overwriting the buffer with the next input.
      // The loop bounds are uniform across the block, so all threads reach
      // this barrier.
      __syncthreads();
      for (CeedInt k = i; k < u_size; k += blockDim.x) {
        s_buffer_1[k] = cur_u[k];
      }
      CeedInt pre = u_size;
      CeedInt post = 1;
      for (CeedInt d = 0; d < BASIS_DIM; d++) {
        // Make the previous pass's writes visible before they are read
        __syncthreads();
        // Update buffers used: ping-pong between the two scratch buffers,
        // writing the last pass directly to the output
        pre /= P;
        const CeedScalar *in = d % 2 ? s_buffer_2 : s_buffer_1;
        CeedScalar *out = d == BASIS_DIM - 1 ? cur_v : (d % 2 ? s_buffer_1 : s_buffer_2);

        // Contract along middle index: view `in` as [pre][P][post] and
        // `out` as [pre][Q][post]
        const CeedInt writeLen = pre * post * Q;
        for (CeedInt k = i; k < writeLen; k += blockDim.x) {
          const CeedInt c = k % post;
          const CeedInt j = (k / post) % Q;
          const CeedInt a = k / (post * Q);

          CeedScalar vk = 0;
          for (CeedInt b = 0; b < P; b++)
            vk += s_interp_1d[j*stride0 + b*stride1] * in[(a*P + b)*post + c];

          out[k] = vk;
        }

        post *= Q;
      }
    }
  }
}
76 
77 //------------------------------------------------------------------------------
78 // Grad
79 //------------------------------------------------------------------------------
// Apply the tensor-product gradient operator, one element per block iteration.
//
// For each derivative direction dim_1, BASIS_DIM one-dimensional contractions
// are chained: grad_1d is used when the contraction direction equals dim_1 and
// interp_1d otherwise.
//
// Parameters:
//   num_elem   number of elements to process
//   transpose  zero: each direction's result is written to its own slice of
//              v; nonzero: u carries one slice per direction and the results
//              of all directions are accumulated (+=) into the same v entries
//   interp_1d  1D interpolation operator, BASIS_Q_1D * BASIS_P_1D entries
//   grad_1d    1D gradient operator, BASIS_Q_1D * BASIS_P_1D entries
//   u          input
//   v          output
// NOTE(review): in transpose mode v is only ever accumulated into, so it
// presumably has to be zeroed by the caller - confirm on the host side.
//
// Only the x components of the thread and block indices are used. Shared
// memory holds both 1D operators plus two scratch buffers of BASIS_BUF_LEN
// entries each.
extern "C" __global__ void Grad(const CeedInt num_elem, const CeedInt transpose,
                                const CeedScalar *__restrict__ interp_1d,
                                const CeedScalar *__restrict__ grad_1d,
                                const CeedScalar *__restrict__ u,
                                CeedScalar *__restrict__ v) {
  const CeedInt tid = threadIdx.x;

  // Layout of s_mem: [interp_1d | grad_1d | scratch 1 | scratch 2]
  __shared__ CeedScalar s_mem[2 * (BASIS_Q_1D * BASIS_P_1D + BASIS_BUF_LEN)];
  CeedScalar *s_interp_1d = s_mem;
  CeedScalar *s_grad_1d = s_interp_1d + BASIS_Q_1D * BASIS_P_1D;
  CeedScalar *s_buffer_1 = s_grad_1d + BASIS_Q_1D * BASIS_P_1D;
  CeedScalar *s_buffer_2 = s_buffer_1 + BASIS_BUF_LEN;

  // Stage both 1D operators in shared memory
  for (CeedInt n = tid; n < BASIS_Q_1D * BASIS_P_1D; n += blockDim.x) {
    s_interp_1d[n] = interp_1d[n];
    s_grad_1d[n] = grad_1d[n];
  }

  // Direction-dependent sizes and strides
  CeedInt P, Q, stride0, stride1;
  CeedInt u_stride, v_stride, u_dim_stride, v_dim_stride;
  if (transpose) {
    P = BASIS_Q_1D;
    Q = BASIS_P_1D;
    stride0 = 1;
    stride1 = BASIS_P_1D;
    u_stride = BASIS_NUM_QPTS;
    v_stride = BASIS_NUM_NODES;
    u_dim_stride = num_elem * BASIS_NUM_QPTS * BASIS_NUM_COMP;
    v_dim_stride = 0;
  } else {
    P = BASIS_P_1D;
    Q = BASIS_Q_1D;
    stride0 = BASIS_P_1D;
    stride1 = 1;
    u_stride = BASIS_NUM_NODES;
    v_stride = BASIS_NUM_QPTS;
    u_dim_stride = 0;
    v_dim_stride = num_elem * BASIS_NUM_QPTS * BASIS_NUM_COMP;
  }
  const CeedInt u_comp_stride = num_elem * u_stride;
  const CeedInt v_comp_stride = num_elem * v_stride;

  // Blocks stride over the elements
  for (CeedInt elem = blockIdx.x; elem < num_elem; elem += gridDim.x) {
    for (CeedInt comp = 0; comp < BASIS_NUM_COMP; comp++) {

      // One chain of BASIS_DIM contractions per derivative direction
      for (CeedInt dim_1 = 0; dim_1 < BASIS_DIM; dim_1++) {
        const CeedScalar *cur_u = u + elem * u_stride + dim_1 * u_dim_stride +
                                  comp * u_comp_stride;
        CeedScalar *cur_v = v + elem * v_stride + dim_1 * v_dim_stride +
                            comp * v_comp_stride;
        CeedInt pre = u_stride;  // entries per element/component of the input
        CeedInt post = 1;

        for (CeedInt dim_2 = 0; dim_2 < BASIS_DIM; dim_2++) {
          __syncthreads();
          pre /= P;

          const bool is_first = (dim_2 == 0);
          const bool is_last = (dim_2 == BASIS_DIM - 1);
          const bool is_odd = (dim_2 % 2) != 0;

          // Gradient operator along the derivative direction, interpolation
          // along every other direction
          const CeedScalar *op = s_interp_1d;
          if (dim_1 == dim_2) op = s_grad_1d;

          // First pass reads the global input, later passes read the scratch
          // buffer the previous pass wrote
          const CeedScalar *in;
          if (is_first) {
            in = cur_u;
          } else if (is_odd) {
            in = s_buffer_2;
          } else {
            in = s_buffer_1;
          }

          // Last pass writes the global output, earlier passes write scratch
          CeedScalar *out;
          if (is_last) {
            out = cur_v;
          } else if (is_odd) {
            out = s_buffer_1;
          } else {
            out = s_buffer_2;
          }

          // In transpose mode the final pass adds onto the existing output
          const bool accumulate = transpose && is_last;

          // Contract along the middle index of [pre][P][post]
          const CeedInt writeLen = pre * post * Q;
          for (CeedInt k = tid; k < writeLen; k += blockDim.x) {
            const CeedInt c = k % post;
            const CeedInt j = (k / post) % Q;
            const CeedInt a = k / (post * Q);

            CeedScalar sum = 0;
            for (CeedInt b = 0; b < P; b++) {
              sum += op[j * stride0 + b * stride1] * in[(a * P + b) * post + c];
            }

            if (accumulate) {
              out[k] += sum;
            } else {
              out[k] = sum;
            }
          }

          post *= Q;
        }
      }
    }
  }
}
154 
155 //------------------------------------------------------------------------------
156 // 1D quadrature weights
157 //------------------------------------------------------------------------------
// Fill w with the 1D quadrature weights for every element.
//
// Each thread handles whole elements, starting at its global x index and
// advancing by the total number of threads in the grid. For element e the
// BASIS_Q_1D weights are written to w[e * BASIS_Q_1D + q].
__device__ void Weight1d(const CeedInt num_elem, const CeedScalar *q_weight_1d,
                         CeedScalar *w) {
  // Per-thread copy of the 1D weights
  CeedScalar w1d[BASIS_Q_1D];
  for (CeedInt q = 0; q < BASIS_Q_1D; q++) {
    w1d[q] = q_weight_1d[q];
  }

  CeedInt e = blockIdx.x * blockDim.x + threadIdx.x;
  while (e < num_elem) {
    for (CeedInt q = 0; q < BASIS_Q_1D; q++) {
      w[e*BASIS_Q_1D + q] = w1d[q];
    }
    e += blockDim.x * gridDim.x;
  }
}
172 
173 //------------------------------------------------------------------------------
174 // 2D quadrature weights
175 //------------------------------------------------------------------------------
// Fill w with the 2D tensor-product quadrature weights for every element.
//
// Each thread handles whole elements, starting at its global x index and
// advancing by the total number of threads in the grid. For element e the
// product w1d[i] * w1d[j] is written to
// w[e * BASIS_Q_1D^2 + i + j * BASIS_Q_1D].
__device__ void Weight2d(const CeedInt num_elem, const CeedScalar *q_weight_1d,
                         CeedScalar *w) {
  // Per-thread copy of the 1D weights
  CeedScalar w1d[BASIS_Q_1D];
  for (CeedInt q = 0; q < BASIS_Q_1D; q++) {
    w1d[q] = q_weight_1d[q];
  }

  CeedInt e = blockIdx.x * blockDim.x + threadIdx.x;
  while (e < num_elem) {
    for (CeedInt i = 0; i < BASIS_Q_1D; i++) {
      for (CeedInt j = 0; j < BASIS_Q_1D; j++) {
        const CeedInt ind = e*BASIS_Q_1D*BASIS_Q_1D + i + j*BASIS_Q_1D;
        w[ind] = w1d[i]*w1d[j];
      }
    }
    e += blockDim.x * gridDim.x;
  }
}
191 
192 //------------------------------------------------------------------------------
193 // 3D quadrature weights
194 //------------------------------------------------------------------------------
// Fill w with the 3D tensor-product quadrature weights for every element.
//
// Each thread handles whole elements, starting at its global x index and
// advancing by the total number of threads in the grid. For element e the
// product w1d[i] * w1d[j] * w1d[k] is written to
// w[e * BASIS_Q_1D^3 + i + j * BASIS_Q_1D + k * BASIS_Q_1D^2].
__device__ void Weight3d(const CeedInt num_elem, const CeedScalar *q_weight_1d,
                         CeedScalar *w) {
  // Per-thread copy of the 1D weights
  CeedScalar w1d[BASIS_Q_1D];
  for (CeedInt q = 0; q < BASIS_Q_1D; q++) {
    w1d[q] = q_weight_1d[q];
  }

  CeedInt e = blockIdx.x * blockDim.x + threadIdx.x;
  while (e < num_elem) {
    for (CeedInt i = 0; i < BASIS_Q_1D; i++) {
      for (CeedInt j = 0; j < BASIS_Q_1D; j++) {
        for (CeedInt k = 0; k < BASIS_Q_1D; k++) {
          const CeedInt ind = e*BASIS_Q_1D*BASIS_Q_1D*BASIS_Q_1D + i +
                              j*BASIS_Q_1D + k*BASIS_Q_1D*BASIS_Q_1D;
          w[ind] = w1d[i]*w1d[j]*w1d[k];
        }
      }
    }
    e += blockDim.x * gridDim.x;
  }
}
212 
213 //------------------------------------------------------------------------------
214 // Quadrature weights
215 //------------------------------------------------------------------------------
// Kernel entry point for quadrature weights: forwards to the 1D, 2D or 3D
// helper selected by the compile-time constant BASIS_DIM. Any other value of
// BASIS_DIM leaves v untouched.
extern "C" __global__ void Weight(const CeedInt num_elem,
                                  const CeedScalar *__restrict__ q_weight_1d,
                                  CeedScalar *__restrict__ v) {
  switch (BASIS_DIM) {
    case 1:
      Weight1d(num_elem, q_weight_1d, v);
      break;
    case 2:
      Weight2d(num_elem, q_weight_1d, v);
      break;
    case 3:
      Weight3d(num_elem, q_weight_1d, v);
      break;
    default:
      break;
  }
}
226 
227 //------------------------------------------------------------------------------
228