xref: /libCEED/include/ceed/jit-source/sycl/sycl-shared-basis-read-write-templates.h (revision 4f69910b6e3819988a1446e35e0e85e74672bc23)
1 // Copyright (c) 2017-2025, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 /// @file
9 /// Internal header for SYCL shared memory basis read/write templates
10 #include <ceed/types.h>
11 
12 //------------------------------------------------------------------------------
13 // Helper function: load matrices for basis actions
14 //------------------------------------------------------------------------------
15 inline void loadMatrix(const CeedInt N, const CeedScalar *restrict d_B, CeedScalar *restrict B) {
16   const CeedInt item_id    = get_local_linear_id();
17   const CeedInt group_size = get_local_size(0) * get_local_size(1) * get_local_size(2);
18   for (CeedInt i = item_id; i < N; i += group_size) B[i] = d_B[i];
19 }
20 
21 //------------------------------------------------------------------------------
22 // 1D
23 //------------------------------------------------------------------------------
24 
25 //------------------------------------------------------------------------------
26 // E-vector -> single element
27 //------------------------------------------------------------------------------
28 inline void ReadElementStrided1d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
29                                  const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar *restrict d_u,
30                                  private CeedScalar *restrict r_u) {
31   const CeedInt item_id_x = get_local_id(0);
32   const CeedInt elem      = get_global_id(2);
33 
34   if (item_id_x < P_1D && elem < num_elem) {
35     const CeedInt node = item_id_x;
36     const CeedInt ind  = node * strides_node + elem * strides_elem;
37     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
38       r_u[comp] = d_u[ind + comp * strides_comp];
39     }
40   }
41 }
42 
43 //------------------------------------------------------------------------------
44 // Single element -> E-vector
45 //------------------------------------------------------------------------------
46 inline void WriteElementStrided1d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
47                                   const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar *restrict r_v,
48                                   global CeedScalar *restrict d_v) {
49   const CeedInt item_id_x = get_local_id(0);
50   const CeedInt elem      = get_global_id(2);
51 
52   if (item_id_x < P_1D && elem < num_elem) {
53     const CeedInt node = item_id_x;
54     const CeedInt ind  = node * strides_node + elem * strides_elem;
55     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
56       d_v[ind + comp * strides_comp] = r_v[comp];
57     }
58   }
59 }
60 
61 //------------------------------------------------------------------------------
62 // 2D
63 //------------------------------------------------------------------------------
64 
65 //------------------------------------------------------------------------------
66 // E-vector -> single element
67 //------------------------------------------------------------------------------
68 inline void ReadElementStrided2d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
69                                  const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar *restrict d_u,
70                                  private CeedScalar *restrict r_u) {
71   const CeedInt item_id_x = get_local_id(0);
72   const CeedInt item_id_y = get_local_id(1);
73   const CeedInt elem      = get_global_id(2);
74 
75   if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
76     const CeedInt node = item_id_x + item_id_y * P_1D;
77     const CeedInt ind  = node * strides_node + elem * strides_elem;
78     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
79       r_u[comp] = d_u[ind + comp * strides_comp];
80     }
81   }
82 }
83 
84 //------------------------------------------------------------------------------
85 // Single element -> E-vector
86 //------------------------------------------------------------------------------
87 inline void WriteElementStrided2d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
88                                   const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar *restrict r_v,
89                                   global CeedScalar *restrict d_v) {
90   const CeedInt item_id_x = get_local_id(0);
91   const CeedInt item_id_y = get_local_id(1);
92   const CeedInt elem      = get_global_id(2);
93 
94   if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
95     const CeedInt node = item_id_x + item_id_y * P_1D;
96     const CeedInt ind  = node * strides_node + elem * strides_elem;
97     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
98       d_v[ind + comp * strides_comp] = r_v[comp];
99     }
100   }
101 }
102 
103 //------------------------------------------------------------------------------
104 // 3D
105 //------------------------------------------------------------------------------
106 
107 //------------------------------------------------------------------------------
108 // E-vector -> single element
109 //------------------------------------------------------------------------------
110 inline void ReadElementStrided3d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
111                                  const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar *restrict d_u,
112                                  private CeedScalar *restrict r_u) {
113   const CeedInt item_id_x = get_local_id(0);
114   const CeedInt item_id_y = get_local_id(1);
115   const CeedInt elem      = get_global_id(2);
116 
117   if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
118     for (CeedInt z = 0; z < P_1D; z++) {
119       const CeedInt node = item_id_x + item_id_y * P_1D + z * P_1D * P_1D;
120       const CeedInt ind  = node * strides_node + elem * strides_elem;
121       for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
122         r_u[z + comp * P_1D] = d_u[ind + comp * strides_comp];
123       }
124     }
125   }
126 }
127 
128 //------------------------------------------------------------------------------
129 // Single element -> E-vector
130 //------------------------------------------------------------------------------
131 inline void WriteElementStrided3d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
132                                   const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar *restrict r_v,
133                                   global CeedScalar *restrict d_v) {
134   const CeedInt item_id_x = get_local_id(0);
135   const CeedInt item_id_y = get_local_id(1);
136   const CeedInt elem      = get_global_id(2);
137 
138   if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
139     for (CeedInt z = 0; z < P_1D; z++) {
140       const CeedInt node = item_id_x + item_id_y * P_1D + z * P_1D * P_1D;
141       const CeedInt ind  = node * strides_node + elem * strides_elem;
142       for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
143         d_v[ind + comp * strides_comp] = r_v[z + comp * P_1D];
144       }
145     }
146   }
147 }
148