xref: /libCEED/include/ceed/jit-source/sycl/sycl-shared-basis-read-write-templates.h (revision 5ebd836c59d60a2e5e1cb67f6731404c7da26f85)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 /// @file
9 /// Internal header for SYCL shared memory basis read/write templates
10 
11 #include <ceed.h>
12 #include "sycl-types.h"
13 
14 //------------------------------------------------------------------------------
15 // Helper function: load matrices for basis actions
16 //------------------------------------------------------------------------------
// Cooperatively copy entries [0, N) of d_B into B using every work-item of the
// work-group. Work-item k (linear local id) copies entries k, k + W, k + 2W, ...
// where W is the total number of work-items in the group.
// No barrier is issued here; a caller that reads entries of B written by other
// work-items needs to synchronize first.
// NOTE(review): B presumably points to work-group local memory (this is the
// shared-memory basis path) — confirm at the call sites.
inline void loadMatrix(const CeedInt N, const CeedScalar *restrict d_B, CeedScalar *restrict B) {
  const CeedInt num_items = get_local_size(0) * get_local_size(1) * get_local_size(2);
  CeedInt       k         = get_local_linear_id();

  while (k < N) {
    B[k] = d_B[k];
    k += num_items;
  }
}
22 
23 //------------------------------------------------------------------------------
24 // 1D
25 //------------------------------------------------------------------------------
26 
27 //------------------------------------------------------------------------------
28 // E-vector -> single element
29 //------------------------------------------------------------------------------
// Gather one node of one 1D element from a strided E-vector into private memory.
// The local x id selects the node and the global z id selects the element. For each
// component c in [0, NUM_COMP):
//   r_u[c] = d_u[node * strides_node + c * strides_comp + elem * strides_elem]
// Work-items whose local x id is >= P_1D, or whose element id is >= num_elem,
// leave r_u untouched.
inline void ReadElementStrided1d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
                                 const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar *restrict d_u,
                                 private CeedScalar *restrict r_u) {
  const CeedInt node = get_local_id(0);
  const CeedInt elem = get_global_id(2);

  // Work-items outside the node range or past the last element do nothing
  if (node >= P_1D || elem >= num_elem) return;

  const CeedInt base = elem * strides_elem + node * strides_node;

  for (CeedInt c = 0; c < NUM_COMP; ++c) r_u[c] = d_u[base + c * strides_comp];
}
44 
45 //------------------------------------------------------------------------------
46 // Single element -> E-vector
47 //------------------------------------------------------------------------------
// Scatter one node of one 1D element from private memory into a strided E-vector.
// The local x id selects the node and the global z id selects the element. For each
// component c in [0, NUM_COMP):
//   d_v[node * strides_node + c * strides_comp + elem * strides_elem] = r_v[c]
// Work-items whose local x id is >= P_1D, or whose element id is >= num_elem,
// write nothing.
inline void WriteElementStrided1d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
                                  const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar *restrict r_v,
                                  global CeedScalar *restrict d_v) {
  const CeedInt node = get_local_id(0);
  const CeedInt elem = get_global_id(2);

  // Work-items outside the node range or past the last element do nothing
  if (node >= P_1D || elem >= num_elem) return;

  const CeedInt base = elem * strides_elem + node * strides_node;

  for (CeedInt c = 0; c < NUM_COMP; ++c) d_v[base + c * strides_comp] = r_v[c];
}
62 
63 //------------------------------------------------------------------------------
64 // 2D
65 //------------------------------------------------------------------------------
66 
67 //------------------------------------------------------------------------------
68 // E-vector -> single element
69 //------------------------------------------------------------------------------
// Gather one node of one 2D tensor-product element from a strided E-vector into
// private memory. The local (x, y) ids select the node, numbered x + y * P_1D, and
// the global z id selects the element. For each component c in [0, NUM_COMP):
//   r_u[c] = d_u[node * strides_node + c * strides_comp + elem * strides_elem]
// Work-items with local x or y id >= P_1D, or element id >= num_elem, leave r_u untouched.
inline void ReadElementStrided2d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
                                 const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar *restrict d_u,
                                 private CeedScalar *restrict r_u) {
  const CeedInt x    = get_local_id(0);
  const CeedInt y    = get_local_id(1);
  const CeedInt elem = get_global_id(2);

  // Work-items outside the P_1D x P_1D node grid or past the last element do nothing
  if (x >= P_1D || y >= P_1D || elem >= num_elem) return;

  const CeedInt node = y * P_1D + x;
  const CeedInt base = elem * strides_elem + node * strides_node;

  for (CeedInt c = 0; c < NUM_COMP; ++c) r_u[c] = d_u[base + c * strides_comp];
}
85 
86 //------------------------------------------------------------------------------
87 // Single element -> E-vector
88 //------------------------------------------------------------------------------
// Scatter one node of one 2D tensor-product element from private memory into a
// strided E-vector. The local (x, y) ids select the node, numbered x + y * P_1D, and
// the global z id selects the element. For each component c in [0, NUM_COMP):
//   d_v[node * strides_node + c * strides_comp + elem * strides_elem] = r_v[c]
// Work-items with local x or y id >= P_1D, or element id >= num_elem, write nothing.
inline void WriteElementStrided2d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
                                  const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar *restrict r_v,
                                  global CeedScalar *restrict d_v) {
  const CeedInt x    = get_local_id(0);
  const CeedInt y    = get_local_id(1);
  const CeedInt elem = get_global_id(2);

  // Work-items outside the P_1D x P_1D node grid or past the last element do nothing
  if (x >= P_1D || y >= P_1D || elem >= num_elem) return;

  const CeedInt node = y * P_1D + x;
  const CeedInt base = elem * strides_elem + node * strides_node;

  for (CeedInt c = 0; c < NUM_COMP; ++c) d_v[base + c * strides_comp] = r_v[c];
}
104 
105 //------------------------------------------------------------------------------
106 // 3D
107 //------------------------------------------------------------------------------
108 
109 //------------------------------------------------------------------------------
110 // E-vector -> single element
111 //------------------------------------------------------------------------------
// Gather a z-column of nodes of one 3D tensor-product element from a strided
// E-vector into private memory. The local (x, y) ids select the column and the global
// z id selects the element. For each layer z in [0, P_1D), the node index is
// x + y * P_1D + z * P_1D * P_1D, and for each component c:
//   r_u[z + c * P_1D] = d_u[node * strides_node + c * strides_comp + elem * strides_elem]
// r_u is therefore indexed up to NUM_COMP * P_1D - 1.
// Work-items with local x or y id >= P_1D, or element id >= num_elem, leave r_u untouched.
inline void ReadElementStrided3d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
                                 const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar *restrict d_u,
                                 private CeedScalar *restrict r_u) {
  const CeedInt x    = get_local_id(0);
  const CeedInt y    = get_local_id(1);
  const CeedInt elem = get_global_id(2);

  // Work-items outside the P_1D x P_1D column grid or past the last element do nothing
  if (x >= P_1D || y >= P_1D || elem >= num_elem) return;

  const CeedInt column     = y * P_1D + x;  // node index within the z = 0 layer
  const CeedInt layer_size = P_1D * P_1D;   // number of nodes in one z layer
  const CeedInt elem_base  = elem * strides_elem;

  for (CeedInt z = 0; z < P_1D; ++z) {
    const CeedInt offset = (column + z * layer_size) * strides_node + elem_base;

    for (CeedInt c = 0; c < NUM_COMP; ++c) r_u[z + c * P_1D] = d_u[offset + c * strides_comp];
  }
}
129 
130 //------------------------------------------------------------------------------
131 // Single element -> E-vector
132 //------------------------------------------------------------------------------
// Scatter a z-column of nodes of one 3D tensor-product element from private memory
// into a strided E-vector. The local (x, y) ids select the column and the global z id
// selects the element. For each layer z in [0, P_1D), the node index is
// x + y * P_1D + z * P_1D * P_1D, and for each component c:
//   d_v[node * strides_node + c * strides_comp + elem * strides_elem] = r_v[z + c * P_1D]
// r_v is therefore read up to index NUM_COMP * P_1D - 1.
// Work-items with local x or y id >= P_1D, or element id >= num_elem, write nothing.
inline void WriteElementStrided3d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
                                  const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar *restrict r_v,
                                  global CeedScalar *restrict d_v) {
  const CeedInt x    = get_local_id(0);
  const CeedInt y    = get_local_id(1);
  const CeedInt elem = get_global_id(2);

  // Work-items outside the P_1D x P_1D column grid or past the last element do nothing
  if (x >= P_1D || y >= P_1D || elem >= num_elem) return;

  const CeedInt column     = y * P_1D + x;  // node index within the z = 0 layer
  const CeedInt layer_size = P_1D * P_1D;   // number of nodes in one z layer
  const CeedInt elem_base  = elem * strides_elem;

  for (CeedInt z = 0; z < P_1D; ++z) {
    const CeedInt offset = (column + z * layer_size) * strides_node + elem_base;

    for (CeedInt c = 0; c < NUM_COMP; ++c) d_v[offset + c * strides_comp] = r_v[z + c * P_1D];
  }
}
150