// Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed

/// @file
/// Internal header for SYCL shared memory basis read/write templates
10c0b5abf0SJeremy L Thompson #include <ceed/types.h>
11bd882c8aSJames Wright
//------------------------------------------------------------------------------
// Helper function: load matrices for basis actions
//------------------------------------------------------------------------------
// Copy the first N entries of d_B into B, splitting the entries across the
// work-items of the calling work-group.
//   N    number of entries to copy
//   d_B  source array
//   B    destination array
// Each work-item begins at its local linear id and advances by the product of
// the three local sizes. No barrier is issued here.
// NOTE(review): callers presumably synchronize before reading B - confirm.
inline void loadMatrix(const CeedInt N, const CeedScalar *restrict d_B, CeedScalar *restrict B) {
  const CeedInt stride = get_local_size(0) * get_local_size(1) * get_local_size(2);
  CeedInt       entry  = get_local_linear_id();

  while (entry < N) {
    B[entry] = d_B[entry];
    entry += stride;
  }
}
20bd882c8aSJames Wright
//------------------------------------------------------------------------------
// 1D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// E-vector -> single element
//------------------------------------------------------------------------------
// Gather the values of one node of a 1D element from a strided global array
// into a private array.
//   NUM_COMP      number of components read per node
//   P_1D          number of nodes in the element
//   num_elem      number of elements; work-items with elem >= num_elem skip
//   strides_node  index stride between consecutive nodes
//   strides_comp  index stride between consecutive components
//   strides_elem  index stride between consecutive elements
//   d_u           global source array
//   r_u           private destination, one entry per component
// The node is taken from local id 0 and the element from global id 2.
inline void ReadElementStrided1d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
                                 const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar *restrict d_u,
                                 private CeedScalar *restrict r_u) {
  const CeedInt node = get_local_id(0);
  const CeedInt elem = get_global_id(2);

  // Work-items beyond the element's nodes or past the last element do nothing
  if (node >= P_1D || elem >= num_elem) return;

  const CeedInt offset = node * strides_node + elem * strides_elem;

  for (CeedInt c = 0; c < NUM_COMP; c++) r_u[c] = d_u[offset + c * strides_comp];
}
42bd882c8aSJames Wright
//------------------------------------------------------------------------------
// Single element -> E-vector
//------------------------------------------------------------------------------
// Scatter the values of one node of a 1D element from a private array into a
// strided global array.
//   NUM_COMP      number of components written per node
//   P_1D          number of nodes in the element
//   num_elem      number of elements; work-items with elem >= num_elem skip
//   strides_node  index stride between consecutive nodes
//   strides_comp  index stride between consecutive components
//   strides_elem  index stride between consecutive elements
//   r_v           private source, one entry per component
//   d_v           global destination array
// The node is taken from local id 0 and the element from global id 2.
inline void WriteElementStrided1d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
                                  const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar *restrict r_v,
                                  global CeedScalar *restrict d_v) {
  const CeedInt node = get_local_id(0);
  const CeedInt elem = get_global_id(2);

  // Work-items beyond the element's nodes or past the last element do nothing
  if (node >= P_1D || elem >= num_elem) return;

  const CeedInt offset = node * strides_node + elem * strides_elem;

  for (CeedInt c = 0; c < NUM_COMP; c++) d_v[offset + c * strides_comp] = r_v[c];
}
60bd882c8aSJames Wright
//------------------------------------------------------------------------------
// 2D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// E-vector -> single element
//------------------------------------------------------------------------------
// Gather the values of one node of a 2D element from a strided global array
// into a private array.
//   NUM_COMP      number of components read per node
//   P_1D          number of nodes per dimension
//   num_elem      number of elements; work-items with elem >= num_elem skip
//   strides_node  index stride between consecutive nodes
//   strides_comp  index stride between consecutive components
//   strides_elem  index stride between consecutive elements
//   d_u           global source array
//   r_u           private destination, one entry per component
// The node index is local id 0 + local id 1 * P_1D; the element is global id 2.
inline void ReadElementStrided2d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
                                 const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar *restrict d_u,
                                 private CeedScalar *restrict r_u) {
  const CeedInt ix   = get_local_id(0);
  const CeedInt iy   = get_local_id(1);
  const CeedInt elem = get_global_id(2);

  // Work-items outside the P_1D x P_1D node grid or past the last element do nothing
  if (ix >= P_1D || iy >= P_1D || elem >= num_elem) return;

  const CeedInt node   = ix + iy * P_1D;
  const CeedInt offset = node * strides_node + elem * strides_elem;

  for (CeedInt c = 0; c < NUM_COMP; c++) r_u[c] = d_u[offset + c * strides_comp];
}
83bd882c8aSJames Wright
//------------------------------------------------------------------------------
// Single element -> E-vector
//------------------------------------------------------------------------------
// Scatter the values of one node of a 2D element from a private array into a
// strided global array.
//   NUM_COMP      number of components written per node
//   P_1D          number of nodes per dimension
//   num_elem      number of elements; work-items with elem >= num_elem skip
//   strides_node  index stride between consecutive nodes
//   strides_comp  index stride between consecutive components
//   strides_elem  index stride between consecutive elements
//   r_v           private source, one entry per component
//   d_v           global destination array
// The node index is local id 0 + local id 1 * P_1D; the element is global id 2.
inline void WriteElementStrided2d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
                                  const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar *restrict r_v,
                                  global CeedScalar *restrict d_v) {
  const CeedInt ix   = get_local_id(0);
  const CeedInt iy   = get_local_id(1);
  const CeedInt elem = get_global_id(2);

  // Work-items outside the P_1D x P_1D node grid or past the last element do nothing
  if (ix >= P_1D || iy >= P_1D || elem >= num_elem) return;

  const CeedInt node   = ix + iy * P_1D;
  const CeedInt offset = node * strides_node + elem * strides_elem;

  for (CeedInt c = 0; c < NUM_COMP; c++) d_v[offset + c * strides_comp] = r_v[c];
}
102bd882c8aSJames Wright
//------------------------------------------------------------------------------
// 3D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// E-vector -> single element
//------------------------------------------------------------------------------
// Gather one z-column of nodes of a 3D element from a strided global array
// into a private array.
//   NUM_COMP      number of components read per node
//   P_1D          number of nodes per dimension
//   num_elem      number of elements; work-items with elem >= num_elem skip
//   strides_node  index stride between consecutive nodes
//   strides_comp  index stride between consecutive components
//   strides_elem  index stride between consecutive elements
//   d_u           global source array
//   r_u           private destination, indexed as z + comp * P_1D
// The (x, y) position comes from local ids 0 and 1 and the element from
// global id 2; each work-item loops over all P_1D values of z itself.
inline void ReadElementStrided3d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
                                 const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar *restrict d_u,
                                 private CeedScalar *restrict r_u) {
  const CeedInt ix   = get_local_id(0);
  const CeedInt iy   = get_local_id(1);
  const CeedInt elem = get_global_id(2);

  // Work-items outside the P_1D x P_1D node grid or past the last element do nothing
  if (ix >= P_1D || iy >= P_1D || elem >= num_elem) return;

  // Offset of this work-item inside each z-plane of the element
  const CeedInt in_plane = ix + iy * P_1D;

  for (CeedInt z = 0; z < P_1D; z++) {
    const CeedInt node   = in_plane + z * P_1D * P_1D;
    const CeedInt offset = node * strides_node + elem * strides_elem;

    for (CeedInt c = 0; c < NUM_COMP; c++) r_u[z + c * P_1D] = d_u[offset + c * strides_comp];
  }
}
127bd882c8aSJames Wright
//------------------------------------------------------------------------------
// Single element -> E-vector
//------------------------------------------------------------------------------
// Scatter one z-column of nodes of a 3D element from a private array into a
// strided global array.
//   NUM_COMP      number of components written per node
//   P_1D          number of nodes per dimension
//   num_elem      number of elements; work-items with elem >= num_elem skip
//   strides_node  index stride between consecutive nodes
//   strides_comp  index stride between consecutive components
//   strides_elem  index stride between consecutive elements
//   r_v           private source, indexed as z + comp * P_1D
//   d_v           global destination array
// The (x, y) position comes from local ids 0 and 1 and the element from
// global id 2; each work-item loops over all P_1D values of z itself.
inline void WriteElementStrided3d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
                                  const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar *restrict r_v,
                                  global CeedScalar *restrict d_v) {
  const CeedInt ix   = get_local_id(0);
  const CeedInt iy   = get_local_id(1);
  const CeedInt elem = get_global_id(2);

  // Work-items outside the P_1D x P_1D node grid or past the last element do nothing
  if (ix >= P_1D || iy >= P_1D || elem >= num_elem) return;

  // Offset of this work-item inside each z-plane of the element
  const CeedInt in_plane = ix + iy * P_1D;

  for (CeedInt z = 0; z < P_1D; z++) {
    const CeedInt node   = in_plane + z * P_1D * P_1D;
    const CeedInt offset = node * strides_node + elem * strides_elem;

    for (CeedInt c = 0; c < NUM_COMP; c++) d_v[offset + c * strides_comp] = r_v[z + c * P_1D];
  }
}
148