1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED: http://github.com/ceed
7
8 /// @file
9 /// Internal header for SYCL shared memory tensor product basis
10 #include <ceed/types.h>
11
12 #include "sycl-shared-basis-read-write-templates.h"
13 #include "sycl-shared-basis-tensor-templates.h"
14
//
// Size relations between the basis macros used by the kernels below:
//   BASIS_NUM_NODES = CeedIntPow(BASIS_P_1D, BASIS_DIM)
//   BASIS_NUM_QPTS  = CeedIntPow(BASIS_Q_1D, BASIS_DIM)
18
19 //------------------------------------------------------------------------------
20 // Interp kernel by dim
21 //------------------------------------------------------------------------------
// Interpolate nodal values (d_U) to quadrature-point values (d_V) by applying the
// 1D interpolation matrix along each of the BASIS_DIM directions.
//
// num_elem     number of elements, forwarded to the strided read/write helpers
// d_interp_1d  BASIS_P_1D * BASIS_Q_1D interpolation matrix, staged into local memory (s_B)
// d_U          input with BASIS_NUM_COMP components, BASIS_NUM_NODES values per element
//              (strided layout as interpreted by ReadElementStrided{1,2,3}d)
// d_V          output with BASIS_NUM_COMP components, BASIS_NUM_QPTS values per element
//              (strided layout as interpreted by WriteElementStrided{1,2,3}d)
kernel void Interp(const CeedInt num_elem, global const CeedScalar *restrict d_interp_1d, global const CeedScalar *restrict d_U,
                   global CeedScalar *restrict d_V) {
  local CeedScalar s_B[BASIS_P_1D * BASIS_Q_1D];

  // Per-work-item storage: BASIS_P_1D (r_U) / BASIS_Q_1D (r_V) values per component in 3D,
  // a single value per component in 1D and 2D
  private CeedScalar r_U[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)];
  private CeedScalar r_V[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)];

  // Scratch is sliced by local id 2 into chunks of T_1D (1D) or T_1D * T_1D (2D/3D) entries;
  // presumably one chunk per element handled by the work-group -- confirm against the launch config
  local CeedScalar  scratch[BASIS_INTERP_SCRATCH_SIZE];
  local CeedScalar *elem_scratch = scratch + get_local_id(2) * T_1D * (BASIS_DIM > 1 ? T_1D : 1);

  // Stage the matrix in local memory; the barrier publishes it to every work-item
  loadMatrix(BASIS_P_1D * BASIS_Q_1D, d_interp_1d, s_B);
  work_group_barrier(CLK_LOCAL_MEM_FENCE);

  switch (BASIS_DIM) {
    case 1:
      ReadElementStrided1d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, d_U, r_U);
      Interp1d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, r_V, elem_scratch);
      WriteElementStrided1d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_V, d_V);
      break;
    case 2:
      ReadElementStrided2d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, d_U, r_U);
      InterpTensor2d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, r_V, elem_scratch);
      WriteElementStrided2d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_V, d_V);
      break;
    case 3:
      ReadElementStrided3d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, d_U, r_U);
      InterpTensor3d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, r_V, elem_scratch);
      WriteElementStrided3d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_V, d_V);
      break;
  }
}
52
// Transpose of Interp: map quadrature-point values (d_U) back to nodal values (d_V) by
// applying the transposed 1D interpolation matrix along each of the BASIS_DIM directions.
//
// Expected local size (from the original design note):
//   1D:    elems_per_block * T_1D
//   2D/3D: elems_per_block * T_1D * T_1D
//
// num_elem     number of elements, forwarded to the strided read/write helpers
// d_interp_1d  BASIS_P_1D * BASIS_Q_1D interpolation matrix, staged into local memory (s_B)
// d_U          input with BASIS_NUM_COMP components, BASIS_NUM_QPTS values per element
// d_V          output with BASIS_NUM_COMP components, BASIS_NUM_NODES values per element
kernel void InterpTranspose(const CeedInt num_elem, global const CeedScalar *restrict d_interp_1d, global const CeedScalar *restrict d_U,
                            global CeedScalar *restrict d_V) {
  local CeedScalar s_B[BASIS_P_1D * BASIS_Q_1D];

  // Per-work-item storage: BASIS_Q_1D (r_U) / BASIS_P_1D (r_V) values per component in 3D,
  // a single value per component in 1D and 2D
  private CeedScalar r_U[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)];
  private CeedScalar r_V[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)];

  // Scratch chunk selected by local id 2 (T_1D entries in 1D, T_1D * T_1D otherwise)
  local CeedScalar  scratch[BASIS_INTERP_SCRATCH_SIZE];
  local CeedScalar *elem_scratch = scratch + get_local_id(2) * T_1D * (BASIS_DIM > 1 ? T_1D : 1);

  // Stage the matrix in local memory; the barrier publishes it to every work-item
  loadMatrix(BASIS_P_1D * BASIS_Q_1D, d_interp_1d, s_B);
  work_group_barrier(CLK_LOCAL_MEM_FENCE);

  switch (BASIS_DIM) {
    case 1:
      ReadElementStrided1d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, d_U, r_U);
      InterpTranspose1d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, r_V, elem_scratch);
      WriteElementStrided1d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, r_V, d_V);
      break;
    case 2:
      ReadElementStrided2d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, d_U, r_U);
      InterpTransposeTensor2d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, r_V, elem_scratch);
      WriteElementStrided2d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, r_V, d_V);
      break;
    case 3:
      ReadElementStrided3d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, d_U, r_U);
      InterpTransposeTensor3d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, r_V, elem_scratch);
      WriteElementStrided3d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, r_V, d_V);
      break;
  }
}
86
87 //------------------------------------------------------------------------------
88 // Grad kernel by dim
89 //------------------------------------------------------------------------------
// Evaluate gradients at quadrature points: nodal values (d_U) -> derivative values (d_V).
//
// num_elem     number of elements, forwarded to the strided read/write helpers
// d_interp_1d  BASIS_P_1D * BASIS_Q_1D interpolation matrix; only needed for 2D/3D
// d_grad_1d    1D gradient matrix, BASIS_Q_1D * BASIS_Q_1D entries when BASIS_HAS_COLLOCATED_GRAD,
//              otherwise BASIS_Q_1D * BASIS_P_1D entries
// d_U          input with BASIS_NUM_COMP components, BASIS_NUM_NODES values per element
// d_V          output written as BASIS_NUM_COMP * BASIS_DIM fields of BASIS_NUM_QPTS values per element
//              (presumably one field per component and derivative direction)
kernel void Grad(const CeedInt num_elem, global const CeedScalar *restrict d_interp_1d, global const CeedScalar *restrict d_grad_1d,
                 global const CeedScalar *restrict d_U, global CeedScalar *restrict d_V) {
  // s_B is only consumed by the 2D/3D tensor kernels (Grad1d reads s_G alone), so in 1D it
  // shrinks to a never-loaded, never-read placeholder instead of occupying P_1D * Q_1D local memory
  local CeedScalar s_B[BASIS_DIM > 1 ? BASIS_P_1D * BASIS_Q_1D : 1];
  local CeedScalar s_G[BASIS_Q_1D * (BASIS_HAS_COLLOCATED_GRAD ? BASIS_Q_1D : BASIS_P_1D)];

  // Per-work-item storage; r_V holds BASIS_DIM results per input component
  private CeedScalar r_U[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)];
  private CeedScalar r_V[BASIS_NUM_COMP * BASIS_DIM * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)];

  // Scratch chunk selected by local id 2 (T_1D entries in 1D, T_1D * T_1D otherwise)
  local CeedScalar  scratch[BASIS_GRAD_SCRATCH_SIZE];
  local CeedScalar *elem_scratch = scratch + get_local_id(2) * T_1D * (BASIS_DIM > 1 ? T_1D : 1);

  // BASIS_DIM is a compile-time constant (it sizes the arrays above), so this guard is uniform
  // across the work-group and every work-item still reaches the barrier below
  if (BASIS_DIM > 1) {
    loadMatrix(BASIS_P_1D * BASIS_Q_1D, d_interp_1d, s_B);
  }
  loadMatrix(BASIS_Q_1D * (BASIS_HAS_COLLOCATED_GRAD ? BASIS_Q_1D : BASIS_P_1D), d_grad_1d, s_G);
  work_group_barrier(CLK_LOCAL_MEM_FENCE);

  if (BASIS_DIM == 1) {
    ReadElementStrided1d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, d_U, r_U);
    Grad1d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_G, r_V, elem_scratch);
    // BASIS_NUM_COMP == BASIS_NUM_COMP * BASIS_DIM when BASIS_DIM == 1
    WriteElementStrided1d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_V, d_V);

  } else if (BASIS_DIM == 2) {
    ReadElementStrided2d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, d_U, r_U);
    GradTensor2d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, s_G, r_V, elem_scratch);
    WriteElementStrided2d(BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_V, d_V);

  } else if (BASIS_DIM == 3) {
    ReadElementStrided3d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, d_U, r_U);
    if (BASIS_HAS_COLLOCATED_GRAD) {
      GradTensorCollocated3d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, s_G, r_V, elem_scratch);
    } else {
      GradTensor3d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, s_G, r_V, elem_scratch);
    }
    WriteElementStrided3d(BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_V, d_V);
  }
}
124
// Transpose of Grad: map derivative values at quadrature points (d_U) back to nodal values (d_V).
//
// num_elem     number of elements, forwarded to the strided read/write helpers
// d_interp_1d  BASIS_P_1D * BASIS_Q_1D interpolation matrix; only needed for 2D/3D
// d_grad_1d    1D gradient matrix, BASIS_Q_1D * BASIS_Q_1D entries when BASIS_HAS_COLLOCATED_GRAD,
//              otherwise BASIS_Q_1D * BASIS_P_1D entries
// d_U          input read as BASIS_NUM_COMP * BASIS_DIM fields of BASIS_NUM_QPTS values per element
//              (presumably one field per component and derivative direction)
// d_V          output with BASIS_NUM_COMP components, BASIS_NUM_NODES values per element
kernel void GradTranspose(const CeedInt num_elem, global const CeedScalar *restrict d_interp_1d, global const CeedScalar *restrict d_grad_1d,
                          global const CeedScalar *restrict d_U, global CeedScalar *restrict d_V) {
  // s_B is only consumed by the 2D/3D tensor kernels (GradTranspose1d reads s_G alone), so in 1D it
  // shrinks to a never-loaded, never-read placeholder instead of occupying P_1D * Q_1D local memory
  local CeedScalar s_B[BASIS_DIM > 1 ? BASIS_P_1D * BASIS_Q_1D : 1];
  local CeedScalar s_G[BASIS_Q_1D * (BASIS_HAS_COLLOCATED_GRAD ? BASIS_Q_1D : BASIS_P_1D)];

  // Per-work-item storage; r_U holds BASIS_DIM inputs per output component
  private CeedScalar r_U[BASIS_NUM_COMP * BASIS_DIM * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)];
  private CeedScalar r_V[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)];

  // Scratch chunk selected by local id 2 (T_1D entries in 1D, T_1D * T_1D otherwise)
  local CeedScalar  scratch[BASIS_GRAD_SCRATCH_SIZE];
  local CeedScalar *elem_scratch = scratch + get_local_id(2) * T_1D * (BASIS_DIM > 1 ? T_1D : 1);

  // BASIS_DIM is a compile-time constant (it sizes the arrays above), so this guard is uniform
  // across the work-group and every work-item still reaches the barrier below
  if (BASIS_DIM > 1) {
    loadMatrix(BASIS_P_1D * BASIS_Q_1D, d_interp_1d, s_B);
  }
  loadMatrix(BASIS_Q_1D * (BASIS_HAS_COLLOCATED_GRAD ? BASIS_Q_1D : BASIS_P_1D), d_grad_1d, s_G);
  work_group_barrier(CLK_LOCAL_MEM_FENCE);

  if (BASIS_DIM == 1) {
    // BASIS_NUM_COMP == BASIS_NUM_COMP * BASIS_DIM when BASIS_DIM == 1
    ReadElementStrided1d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, d_U, r_U);
    GradTranspose1d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_G, r_V, elem_scratch);
    WriteElementStrided1d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, r_V, d_V);

  } else if (BASIS_DIM == 2) {
    ReadElementStrided2d(BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, d_U, r_U);
    GradTransposeTensor2d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, s_G, r_V, elem_scratch);
    WriteElementStrided2d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, r_V, d_V);

  } else if (BASIS_DIM == 3) {
    ReadElementStrided3d(BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, d_U, r_U);
    if (BASIS_HAS_COLLOCATED_GRAD) {
      GradTransposeTensorCollocated3d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, s_G, r_V, elem_scratch);
    } else {
      GradTransposeTensor3d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, s_G, r_V, elem_scratch);
    }
    WriteElementStrided3d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, r_V, d_V);
  }
}
159
160 //------------------------------------------------------------------------------
161 // Weight kernels by dim
162 //------------------------------------------------------------------------------
// Write quadrature weights built from the 1D weights q_weight_1d into d_W
// (a single component with BASIS_NUM_QPTS values per element).
//
// num_elem     number of elements, forwarded to the strided write helpers
// q_weight_1d  BASIS_Q_1D one-dimensional quadrature weights
// d_W          output weights
kernel void Weight(const CeedInt num_elem, global const CeedScalar *restrict q_weight_1d, global CeedScalar *restrict d_W) {
  // BASIS_Q_1D weights per work-item in 3D, a single weight in 1D and 2D
  private CeedScalar r_W[BASIS_DIM > 2 ? BASIS_Q_1D : 1];

  switch (BASIS_DIM) {
    case 1:
      Weight1d(BASIS_Q_1D, q_weight_1d, r_W);
      WriteElementStrided1d(1, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_W, d_W);
      break;
    case 2:
      WeightTensor2d(BASIS_Q_1D, q_weight_1d, r_W);
      WriteElementStrided2d(1, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_W, d_W);
      break;
    case 3:
      WeightTensor3d(BASIS_Q_1D, q_weight_1d, r_W);
      WriteElementStrided3d(1, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_W, d_W);
      break;
  }
}
182