1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED: http://github.com/ceed
7
8 /// @file
9 /// Internal header for SYCL shared memory basis read/write templates
10 #include <ceed/types.h>
11
12 //------------------------------------------------------------------------------
13 // Helper function: load matrices for basis actions
14 //------------------------------------------------------------------------------
// Copy the first N entries of d_B into B, splitting the copy across the work-items of the work-group.
// Each work-item starts at its local linear id and advances by the total work-group size.
// No barrier is issued here.
inline void loadMatrix(const CeedInt N, const CeedScalar *restrict d_B, CeedScalar *restrict B) {
  const CeedInt first  = get_local_linear_id();
  const CeedInt stride = get_local_size(0) * get_local_size(1) * get_local_size(2);
  CeedInt       k      = first;

  while (k < N) {
    B[k] = d_B[k];
    k += stride;
  }
}
20
21 //------------------------------------------------------------------------------
22 // 1D
23 //------------------------------------------------------------------------------
24
25 //------------------------------------------------------------------------------
26 // E-vector -> single element
27 //------------------------------------------------------------------------------
// Gather the NUM_COMP component values of one node from the strided global vector d_u into the private array r_u.
// The node is the work-item's local id in dimension 0; the element is its global id in dimension 2.
// Work-items whose node is >= P_1D or whose element is >= num_elem leave r_u untouched.
inline void ReadElementStrided1d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
                                 const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar *restrict d_u,
                                 private CeedScalar *restrict r_u) {
  const CeedInt tx = get_local_id(0);
  const CeedInt e  = get_global_id(2);

  // Out-of-range work-items have nothing to read
  if (tx >= P_1D || e >= num_elem) return;

  // Offset of component 0 for this node and element
  const CeedInt base = tx * strides_node + e * strides_elem;

  for (CeedInt c = 0; c < NUM_COMP; c++) {
    const CeedInt src = base + c * strides_comp;

    r_u[c] = d_u[src];
  }
}
42
43 //------------------------------------------------------------------------------
44 // Single element -> E-vector
45 //------------------------------------------------------------------------------
// Scatter the NUM_COMP component values held in the private array r_v to one node of the strided global vector d_v.
// The node is the work-item's local id in dimension 0; the element is its global id in dimension 2.
// Work-items whose node is >= P_1D or whose element is >= num_elem write nothing.
inline void WriteElementStrided1d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
                                  const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar *restrict r_v,
                                  global CeedScalar *restrict d_v) {
  const CeedInt tx = get_local_id(0);
  const CeedInt e  = get_global_id(2);

  // Out-of-range work-items have nothing to write
  if (tx >= P_1D || e >= num_elem) return;

  // Offset of component 0 for this node and element
  const CeedInt base = tx * strides_node + e * strides_elem;

  for (CeedInt c = 0; c < NUM_COMP; c++) {
    const CeedInt dst = base + c * strides_comp;

    d_v[dst] = r_v[c];
  }
}
60
61 //------------------------------------------------------------------------------
62 // 2D
63 //------------------------------------------------------------------------------
64
65 //------------------------------------------------------------------------------
66 // E-vector -> single element
67 //------------------------------------------------------------------------------
// Gather the NUM_COMP component values of one node from the strided global vector d_u into the private array r_u.
// The node index is local_id(0) + local_id(1) * P_1D; the element is the work-item's global id in dimension 2.
// Work-items with either local id >= P_1D, or with element >= num_elem, leave r_u untouched.
inline void ReadElementStrided2d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
                                 const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar *restrict d_u,
                                 private CeedScalar *restrict r_u) {
  const CeedInt tx = get_local_id(0);
  const CeedInt ty = get_local_id(1);
  const CeedInt e  = get_global_id(2);

  // Out-of-range work-items have nothing to read
  if (tx >= P_1D || ty >= P_1D || e >= num_elem) return;

  const CeedInt node = tx + ty * P_1D;
  const CeedInt base = node * strides_node + e * strides_elem;
  CeedInt       c    = 0;

  while (c < NUM_COMP) {
    r_u[c] = d_u[base + c * strides_comp];
    c++;
  }
}
83
84 //------------------------------------------------------------------------------
85 // Single element -> E-vector
86 //------------------------------------------------------------------------------
// Scatter the NUM_COMP component values held in the private array r_v to one node of the strided global vector d_v.
// The node index is local_id(0) + local_id(1) * P_1D; the element is the work-item's global id in dimension 2.
// Work-items with either local id >= P_1D, or with element >= num_elem, write nothing.
inline void WriteElementStrided2d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
                                  const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar *restrict r_v,
                                  global CeedScalar *restrict d_v) {
  const CeedInt tx = get_local_id(0);
  const CeedInt ty = get_local_id(1);
  const CeedInt e  = get_global_id(2);

  // Out-of-range work-items have nothing to write
  if (tx >= P_1D || ty >= P_1D || e >= num_elem) return;

  const CeedInt node = tx + ty * P_1D;
  const CeedInt base = node * strides_node + e * strides_elem;
  CeedInt       c    = 0;

  while (c < NUM_COMP) {
    d_v[base + c * strides_comp] = r_v[c];
    c++;
  }
}
102
103 //------------------------------------------------------------------------------
104 // 3D
105 //------------------------------------------------------------------------------
106
107 //------------------------------------------------------------------------------
108 // E-vector -> single element
109 //------------------------------------------------------------------------------
// Gather a column of P_1D nodes (all z levels at this work-item's x/y position) from the strided global vector d_u.
// Values are stored in the private array r_u at r_u[z + comp * P_1D], so r_u receives NUM_COMP * P_1D entries.
// The x/y position comes from local ids 0 and 1; the element is the work-item's global id in dimension 2.
// Work-items with either local id >= P_1D, or with element >= num_elem, leave r_u untouched.
inline void ReadElementStrided3d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
                                 const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar *restrict d_u,
                                 private CeedScalar *restrict r_u) {
  const CeedInt tx = get_local_id(0);
  const CeedInt ty = get_local_id(1);
  const CeedInt e  = get_global_id(2);

  // Out-of-range work-items have nothing to read
  if (tx >= P_1D || ty >= P_1D || e >= num_elem) return;

  // Terms that do not depend on the z level
  const CeedInt node_xy     = tx + ty * P_1D;
  const CeedInt elem_offset = e * strides_elem;

  for (CeedInt level = 0; level < P_1D; level++) {
    const CeedInt node = node_xy + level * P_1D * P_1D;
    const CeedInt base = node * strides_node + elem_offset;

    for (CeedInt c = 0; c < NUM_COMP; c++) {
      r_u[level + c * P_1D] = d_u[base + c * strides_comp];
    }
  }
}
127
128 //------------------------------------------------------------------------------
129 // Single element -> E-vector
130 //------------------------------------------------------------------------------
// Scatter a column of P_1D nodes (all z levels at this work-item's x/y position) to the strided global vector d_v.
// Values are taken from the private array r_v at r_v[z + comp * P_1D], so NUM_COMP * P_1D entries of r_v are read.
// The x/y position comes from local ids 0 and 1; the element is the work-item's global id in dimension 2.
// Work-items with either local id >= P_1D, or with element >= num_elem, write nothing.
inline void WriteElementStrided3d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
                                  const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar *restrict r_v,
                                  global CeedScalar *restrict d_v) {
  const CeedInt tx = get_local_id(0);
  const CeedInt ty = get_local_id(1);
  const CeedInt e  = get_global_id(2);

  // Out-of-range work-items have nothing to write
  if (tx >= P_1D || ty >= P_1D || e >= num_elem) return;

  // Terms that do not depend on the z level
  const CeedInt node_xy     = tx + ty * P_1D;
  const CeedInt elem_offset = e * strides_elem;

  for (CeedInt level = 0; level < P_1D; level++) {
    const CeedInt node = node_xy + level * P_1D * P_1D;
    const CeedInt base = node * strides_node + elem_offset;

    for (CeedInt c = 0; c < NUM_COMP; c++) {
      d_v[base + c * strides_comp] = r_v[level + c * P_1D];
    }
  }
}
148