xref: /libCEED/include/ceed/jit-source/hip/hip-ref-basis-tensor.h (revision a0154adecfab8547cdc0febbbf40ac009dbe9d1d)
1*a0154adeSJed Brown // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2*a0154adeSJed Brown // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3*a0154adeSJed Brown //
4*a0154adeSJed Brown // SPDX-License-Identifier: BSD-2-Clause
5*a0154adeSJed Brown //
6*a0154adeSJed Brown // This file is part of CEED:  http://github.com/ceed
7*a0154adeSJed Brown 
8*a0154adeSJed Brown #include <ceed/ceed.h>
9*a0154adeSJed Brown //------------------------------------------------------------------------------
10*a0154adeSJed Brown // Tensor Basis Kernels
11*a0154adeSJed Brown //------------------------------------------------------------------------------
12*a0154adeSJed Brown 
13*a0154adeSJed Brown //------------------------------------------------------------------------------
14*a0154adeSJed Brown // Interp
15*a0154adeSJed Brown //------------------------------------------------------------------------------
// Apply the 1D interpolation matrix along each of the BASIS_DIM tensor
// dimensions, for every component of every element.
//
// Parameters:
//   num_elem   number of elements to process
//   transpose  0: contract over the BASIS_P_1D index (u has BASIS_NUM_NODES
//                 entries per element/component, v has BASIS_NUM_QPTS);
//              nonzero: contract over the BASIS_Q_1D index (sizes swapped)
//   interp_1d  BASIS_Q_1D * BASIS_P_1D entries, read as [q * BASIS_P_1D + p]
//   u          input, indexed as [comp][elem][entry]
//   v          output, indexed as [comp][elem][entry]; overwritten
//
// Each thread block walks the elements blockIdx.x, blockIdx.x + gridDim.x, ...
// and the threads of the block share the entries of one element/component.
// NOTE(review): BASIS_BUF_LEN is defined outside this file; it is assumed to
// be large enough for the input and every intermediate result - confirm.
extern "C" __global__ void Interp(const CeedInt num_elem, const CeedInt transpose,
                                  const CeedScalar *__restrict__ interp_1d,
                                  const CeedScalar *__restrict__ u,
                                  CeedScalar *__restrict__ v) {
  const CeedInt i = threadIdx.x;

  // Shared memory layout: [ interp matrix | buffer 1 | buffer 2 ]
  __shared__ CeedScalar s_mem[BASIS_Q_1D * BASIS_P_1D + 2 * BASIS_BUF_LEN];
  CeedScalar *s_interp_1d = s_mem;
  CeedScalar *s_buffer_1 = s_mem + BASIS_Q_1D * BASIS_P_1D;
  CeedScalar *s_buffer_2 = s_buffer_1 + BASIS_BUF_LEN;
  // Stage the interpolation matrix; the first barrier below publishes it
  for (CeedInt k = i; k < BASIS_Q_1D * BASIS_P_1D; k += blockDim.x) {
    s_interp_1d[k] = interp_1d[k];
  }

  // P is the contracted 1D size, Q the produced 1D size
  const CeedInt P = transpose ? BASIS_Q_1D : BASIS_P_1D;
  const CeedInt Q = transpose ? BASIS_P_1D : BASIS_Q_1D;
  // Strides into s_interp_1d for the output index j (stride0) and the
  // contracted index b (stride1); swapping them applies the transpose
  const CeedInt stride0 = transpose ? 1 : BASIS_P_1D;
  const CeedInt stride1 = transpose ? BASIS_P_1D : 1;
  const CeedInt u_stride = transpose ? BASIS_NUM_QPTS : BASIS_NUM_NODES;
  const CeedInt v_stride = transpose ? BASIS_NUM_NODES : BASIS_NUM_QPTS;
  const CeedInt u_comp_stride = num_elem * (transpose ? BASIS_NUM_QPTS : BASIS_NUM_NODES);
  const CeedInt v_comp_stride = num_elem * (transpose ? BASIS_NUM_NODES : BASIS_NUM_QPTS);
  const CeedInt u_size = transpose ? BASIS_NUM_QPTS : BASIS_NUM_NODES;

  // Apply basis element by element
  for (CeedInt elem = blockIdx.x; elem < num_elem; elem += gridDim.x) {
    for (CeedInt comp = 0; comp < BASIS_NUM_COMP; comp++) {
      const CeedScalar *cur_u = u + elem * u_stride + comp * u_comp_stride;
      CeedScalar *cur_v = v + elem * v_stride + comp * v_comp_stride;

      // When BASIS_DIM is odd, the final contraction of the previous
      // component/element reads s_buffer_1. Wait for every thread to finish
      // that read before the buffer is refilled, otherwise a fast thread can
      // overwrite entries a slower thread still needs. Both enclosing loops
      // have block-uniform trip counts, so all threads reach this barrier.
      __syncthreads();
      for (CeedInt k = i; k < u_size; k += blockDim.x) {
        s_buffer_1[k] = cur_u[k];
      }
      CeedInt pre = u_size;
      CeedInt post = 1;
      for (CeedInt d = 0; d < BASIS_DIM; d++) {
        // Make the previous stage's writes (or the initial load) visible and
        // ensure nobody still reads the buffer about to be overwritten
        __syncthreads();
        // Update buffers used
        pre /= P;
        const CeedScalar *in = d % 2 ? s_buffer_2 : s_buffer_1;
        // The last stage writes straight to global memory
        CeedScalar *out = d == BASIS_DIM - 1 ? cur_v : (d % 2 ? s_buffer_1 : s_buffer_2);

        // Contract along middle index: in is viewed as [pre][P][post],
        // out as [pre][Q][post]
        const CeedInt writeLen = pre * post * Q;
        for (CeedInt k = i; k < writeLen; k += blockDim.x) {
          const CeedInt c = k % post;
          const CeedInt j = (k / post) % Q;
          const CeedInt a = k / (post * Q);

          CeedScalar vk = 0;
          for (CeedInt b = 0; b < P; b++)
            vk += s_interp_1d[j*stride0 + b*stride1] * in[(a*P + b)*post + c];

          out[k] = vk;
        }

        post *= Q;
      }
    }
  }
}
76*a0154adeSJed Brown 
77*a0154adeSJed Brown //------------------------------------------------------------------------------
78*a0154adeSJed Brown // Grad
79*a0154adeSJed Brown //------------------------------------------------------------------------------
// Tensor-product gradient: for each derivative direction, apply grad_1d along
// that direction and interp_1d along all others.
//
//   transpose == 0: reads one field from u and writes BASIS_DIM derivative
//                   fields to v (one per direction).
//   transpose != 0: reads BASIS_DIM fields from u and ADDS each direction's
//                   result into the same v entries, so v must already hold
//                   valid values (presumably zeroed by the caller - confirm).
//
// interp_1d and grad_1d each hold BASIS_Q_1D * BASIS_P_1D entries.
// Each thread block walks the elements blockIdx.x, blockIdx.x + gridDim.x, ...
extern "C" __global__ void Grad(const CeedInt num_elem, const CeedInt transpose,
                                const CeedScalar *__restrict__ interp_1d,
                                const CeedScalar *__restrict__ grad_1d,
                                const CeedScalar *__restrict__ u,
                                CeedScalar *__restrict__ v) {
  const CeedInt tid = threadIdx.x;

  // Shared memory layout: [ interp | grad | scratch A | scratch B ]
  __shared__ CeedScalar s_mem[2 * (BASIS_Q_1D * BASIS_P_1D + BASIS_BUF_LEN)];
  CeedScalar *const s_interp = s_mem;
  CeedScalar *const s_grad = s_interp + BASIS_Q_1D * BASIS_P_1D;
  CeedScalar *const s_scratch_a = s_grad + BASIS_Q_1D * BASIS_P_1D;
  CeedScalar *const s_scratch_b = s_scratch_a + BASIS_BUF_LEN;
  for (CeedInt m = tid; m < BASIS_Q_1D * BASIS_P_1D; m += blockDim.x) {
    s_interp[m] = interp_1d[m];
    s_grad[m] = grad_1d[m];
  }

  // Sizes and strides depend only on the transpose flag
  CeedInt P, Q, stride0, stride1;
  CeedInt u_stride, v_stride, u_comp_stride, v_comp_stride;
  CeedInt u_dim_stride, v_dim_stride;
  if (transpose) {
    P = BASIS_Q_1D;
    Q = BASIS_P_1D;
    stride0 = 1;
    stride1 = BASIS_P_1D;
    u_stride = BASIS_NUM_QPTS;
    v_stride = BASIS_NUM_NODES;
    u_comp_stride = num_elem * (BASIS_NUM_QPTS);
    v_comp_stride = num_elem * (BASIS_NUM_NODES);
    u_dim_stride = num_elem * BASIS_NUM_QPTS * BASIS_NUM_COMP;
    v_dim_stride = 0;
  } else {
    P = BASIS_P_1D;
    Q = BASIS_Q_1D;
    stride0 = BASIS_P_1D;
    stride1 = 1;
    u_stride = BASIS_NUM_NODES;
    v_stride = BASIS_NUM_QPTS;
    u_comp_stride = num_elem * (BASIS_NUM_NODES);
    v_comp_stride = num_elem * (BASIS_NUM_QPTS);
    u_dim_stride = 0;
    v_dim_stride = num_elem * BASIS_NUM_QPTS * BASIS_NUM_COMP;
  }

  // Apply basis element by element
  for (CeedInt elem = blockIdx.x; elem < num_elem; elem += gridDim.x) {
    for (CeedInt comp = 0; comp < BASIS_NUM_COMP; comp++) {

      // One pass per derivative direction, each made of BASIS_DIM contractions
      for (CeedInt deriv_dim = 0; deriv_dim < BASIS_DIM; deriv_dim++) {
        const CeedScalar *cur_u = u + elem * u_stride + deriv_dim * u_dim_stride +
                                  comp * u_comp_stride;
        CeedScalar *cur_v = v + elem * v_stride + deriv_dim * v_dim_stride + comp *
                            v_comp_stride;
        CeedInt pre = transpose ? BASIS_NUM_QPTS : BASIS_NUM_NODES;
        CeedInt post = 1;

        for (CeedInt d = 0; d < BASIS_DIM; d++) {
          // Separates this stage from the previous one's shared-memory traffic
          __syncthreads();
          pre /= P;

          const bool last_stage = (d == BASIS_DIM - 1);
          const bool odd_stage = (d % 2) != 0;

          // Differentiate along deriv_dim, interpolate along the rest
          const CeedScalar *op = s_interp;
          if (deriv_dim == d) op = s_grad;

          // First stage reads global input; later stages alternate buffers
          const CeedScalar *in;
          if (d == 0) {
            in = cur_u;
          } else if (odd_stage) {
            in = s_scratch_b;
          } else {
            in = s_scratch_a;
          }

          // Last stage writes global output; earlier stages alternate buffers
          CeedScalar *out;
          if (last_stage) {
            out = cur_v;
          } else if (odd_stage) {
            out = s_scratch_a;
          } else {
            out = s_scratch_b;
          }

          // Only the transposed final write sums over derivative directions
          const bool accumulate = transpose && last_stage;

          // Contract along middle index: in is [pre][P][post], out is [pre][Q][post]
          const CeedInt writeLen = pre * post * Q;
          for (CeedInt k = tid; k < writeLen; k += blockDim.x) {
            const CeedInt outer = k / post;
            const CeedInt c = k - outer * post;
            const CeedInt j = outer % Q;
            const CeedInt a = outer / Q;

            CeedScalar v_k = 0;
            for (CeedInt b = 0; b < P; b++)
              v_k += op[j * stride0 + b * stride1] * in[(a * P + b) * post + c];

            if (accumulate) {
              out[k] += v_k;
            } else {
              out[k] = v_k;
            }
          }

          post *= Q;
        }
      }
    }
  }
}
154*a0154adeSJed Brown 
155*a0154adeSJed Brown //------------------------------------------------------------------------------
156*a0154adeSJed Brown // 1D quadrature weights
157*a0154adeSJed Brown //------------------------------------------------------------------------------
// Fill w with the 1D quadrature weights, repeated for every element.
// Each thread handles elements starting at its global thread index and
// advancing by the total number of launched threads; element e occupies
// w[e * BASIS_Q_1D .. e * BASIS_Q_1D + BASIS_Q_1D - 1].
__device__ void Weight1d(const CeedInt num_elem, const CeedScalar *q_weight_1d,
                         CeedScalar *w) {
  // Per-thread copy of the 1D weights
  CeedScalar weights[BASIS_Q_1D];
  for (CeedInt q = 0; q < BASIS_Q_1D; q++) {
    weights[q] = q_weight_1d[q];
  }

  CeedInt elem = blockIdx.x * blockDim.x + threadIdx.x;
  while (elem < num_elem) {
    for (CeedInt q = 0; q < BASIS_Q_1D; q++) {
      w[elem*BASIS_Q_1D + q] = weights[q];
    }
    elem += blockDim.x * gridDim.x;
  }
}
172*a0154adeSJed Brown 
173*a0154adeSJed Brown //------------------------------------------------------------------------------
174*a0154adeSJed Brown // 2D quadrature weights
175*a0154adeSJed Brown //------------------------------------------------------------------------------
// Fill w with the 2D tensor-product quadrature weights for every element.
// Entry (i, j) of element e is weights[i] * weights[j], stored at
// w[e * BASIS_Q_1D^2 + i + j * BASIS_Q_1D]. Each thread handles elements
// starting at its global thread index, stepping by the total thread count.
__device__ void Weight2d(const CeedInt num_elem, const CeedScalar *q_weight_1d,
                         CeedScalar *w) {
  // Per-thread copy of the 1D weights
  CeedScalar weights[BASIS_Q_1D];
  for (CeedInt q = 0; q < BASIS_Q_1D; q++) {
    weights[q] = q_weight_1d[q];
  }

  CeedInt elem = blockIdx.x * blockDim.x + threadIdx.x;
  while (elem < num_elem) {
    for (CeedInt qx = 0; qx < BASIS_Q_1D; qx++) {
      for (CeedInt qy = 0; qy < BASIS_Q_1D; qy++) {
        w[elem*BASIS_Q_1D*BASIS_Q_1D + qx + qy*BASIS_Q_1D] = weights[qx]*weights[qy];
      }
    }
    elem += blockDim.x * gridDim.x;
  }
}
191*a0154adeSJed Brown 
192*a0154adeSJed Brown //------------------------------------------------------------------------------
193*a0154adeSJed Brown // 3D quadrature weights
194*a0154adeSJed Brown //------------------------------------------------------------------------------
// Fill w with the 3D tensor-product quadrature weights for every element.
// Entry (i, j, k) of element e is weights[i] * weights[j] * weights[k],
// stored at w[e * BASIS_Q_1D^3 + i + j * BASIS_Q_1D + k * BASIS_Q_1D^2].
// Each thread handles elements starting at its global thread index, stepping
// by the total thread count.
__device__ void Weight3d(const CeedInt num_elem, const CeedScalar *q_weight_1d,
                         CeedScalar *w) {
  // Per-thread copy of the 1D weights
  CeedScalar weights[BASIS_Q_1D];
  for (CeedInt q = 0; q < BASIS_Q_1D; q++) {
    weights[q] = q_weight_1d[q];
  }

  CeedInt elem = blockIdx.x * blockDim.x + threadIdx.x;
  while (elem < num_elem) {
    for (CeedInt qx = 0; qx < BASIS_Q_1D; qx++) {
      for (CeedInt qy = 0; qy < BASIS_Q_1D; qy++) {
        for (CeedInt qz = 0; qz < BASIS_Q_1D; qz++) {
          const CeedInt offset = elem*BASIS_Q_1D*BASIS_Q_1D*BASIS_Q_1D + qx +
                                 qy*BASIS_Q_1D + qz*BASIS_Q_1D*BASIS_Q_1D;
          w[offset] = weights[qx]*weights[qy]*weights[qz];
        }
      }
    }
    elem += blockDim.x * gridDim.x;
  }
}
212*a0154adeSJed Brown 
213*a0154adeSJed Brown //------------------------------------------------------------------------------
214*a0154adeSJed Brown // Quadrature weights
215*a0154adeSJed Brown //------------------------------------------------------------------------------
// Kernel entry point for quadrature weights: forwards to the 1D, 2D or 3D
// helper selected by the compile-time constant BASIS_DIM. For any other
// BASIS_DIM value nothing is written to v.
extern "C" __global__ void Weight(const CeedInt num_elem,
                                  const CeedScalar *__restrict__ q_weight_1d,
                                  CeedScalar *__restrict__ v) {
  switch (BASIS_DIM) {
    case 1:
      Weight1d(num_elem, q_weight_1d, v);
      break;
    case 2:
      Weight2d(num_elem, q_weight_1d, v);
      break;
    case 3:
      Weight3d(num_elem, q_weight_1d, v);
      break;
    default:
      break;
  }
}
226*a0154adeSJed Brown 
227*a0154adeSJed Brown //------------------------------------------------------------------------------
228