xref: /libCEED/include/ceed/jit-source/sycl/sycl-shared-basis-read-write-templates.h (revision c0b5abf0f23b15c4f0ada76f8abe9f8d2b6fa247)
15aed82e4SJeremy L Thompson // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2bd882c8aSJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3bd882c8aSJames Wright //
4bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause
5bd882c8aSJames Wright //
6bd882c8aSJames Wright // This file is part of CEED:  http://github.com/ceed
7bd882c8aSJames Wright 
8bd882c8aSJames Wright /// @file
9bd882c8aSJames Wright /// Internal header for SYCL shared memory basis read/write templates
10bd882c8aSJames Wright 
11*c0b5abf0SJeremy L Thompson #include <ceed/types.h>
12bd882c8aSJames Wright #include "sycl-types.h"
13bd882c8aSJames Wright 
14bd882c8aSJames Wright //------------------------------------------------------------------------------
15bd882c8aSJames Wright // Helper function: load matrices for basis actions
16bd882c8aSJames Wright //------------------------------------------------------------------------------
17bd882c8aSJames Wright inline void loadMatrix(const CeedInt N, const CeedScalar *restrict d_B, CeedScalar *restrict B) {
18bd882c8aSJames Wright   const CeedInt item_id    = get_local_linear_id();
19bd882c8aSJames Wright   const CeedInt group_size = get_local_size(0) * get_local_size(1) * get_local_size(2);
20bd882c8aSJames Wright   for (CeedInt i = item_id; i < N; i += group_size) B[i] = d_B[i];
21bd882c8aSJames Wright }
22bd882c8aSJames Wright 
23bd882c8aSJames Wright //------------------------------------------------------------------------------
24bd882c8aSJames Wright // 1D
25bd882c8aSJames Wright //------------------------------------------------------------------------------
26bd882c8aSJames Wright 
27bd882c8aSJames Wright //------------------------------------------------------------------------------
28bd882c8aSJames Wright // E-vector -> single element
29bd882c8aSJames Wright //------------------------------------------------------------------------------
30bd882c8aSJames Wright inline void ReadElementStrided1d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
31bd882c8aSJames Wright                                  const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar *restrict d_u,
32bd882c8aSJames Wright                                  private CeedScalar *restrict r_u) {
33bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
34bd882c8aSJames Wright   const CeedInt elem      = get_global_id(2);
35bd882c8aSJames Wright 
36bd882c8aSJames Wright   if (item_id_x < P_1D && elem < num_elem) {
37bd882c8aSJames Wright     const CeedInt node = item_id_x;
38bd882c8aSJames Wright     const CeedInt ind  = node * strides_node + elem * strides_elem;
39bd882c8aSJames Wright     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
40bd882c8aSJames Wright       r_u[comp] = d_u[ind + comp * strides_comp];
41bd882c8aSJames Wright     }
42bd882c8aSJames Wright   }
43bd882c8aSJames Wright }
44bd882c8aSJames Wright 
45bd882c8aSJames Wright //------------------------------------------------------------------------------
46bd882c8aSJames Wright // Single element -> E-vector
47bd882c8aSJames Wright //------------------------------------------------------------------------------
48bd882c8aSJames Wright inline void WriteElementStrided1d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
49bd882c8aSJames Wright                                   const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar *restrict r_v,
50bd882c8aSJames Wright                                   global CeedScalar *restrict d_v) {
51bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
52bd882c8aSJames Wright   const CeedInt elem      = get_global_id(2);
53bd882c8aSJames Wright 
54bd882c8aSJames Wright   if (item_id_x < P_1D && elem < num_elem) {
55bd882c8aSJames Wright     const CeedInt node = item_id_x;
56bd882c8aSJames Wright     const CeedInt ind  = node * strides_node + elem * strides_elem;
57bd882c8aSJames Wright     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
58bd882c8aSJames Wright       d_v[ind + comp * strides_comp] = r_v[comp];
59bd882c8aSJames Wright     }
60bd882c8aSJames Wright   }
61bd882c8aSJames Wright }
62bd882c8aSJames Wright 
63bd882c8aSJames Wright //------------------------------------------------------------------------------
64bd882c8aSJames Wright // 2D
65bd882c8aSJames Wright //------------------------------------------------------------------------------
66bd882c8aSJames Wright 
67bd882c8aSJames Wright //------------------------------------------------------------------------------
68bd882c8aSJames Wright // E-vector -> single element
69bd882c8aSJames Wright //------------------------------------------------------------------------------
70bd882c8aSJames Wright inline void ReadElementStrided2d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
71bd882c8aSJames Wright                                  const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar *restrict d_u,
72bd882c8aSJames Wright                                  private CeedScalar *restrict r_u) {
73bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
74bd882c8aSJames Wright   const CeedInt item_id_y = get_local_id(1);
75bd882c8aSJames Wright   const CeedInt elem      = get_global_id(2);
76bd882c8aSJames Wright 
77bd882c8aSJames Wright   if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
78bd882c8aSJames Wright     const CeedInt node = item_id_x + item_id_y * P_1D;
79bd882c8aSJames Wright     const CeedInt ind  = node * strides_node + elem * strides_elem;
80bd882c8aSJames Wright     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
81bd882c8aSJames Wright       r_u[comp] = d_u[ind + comp * strides_comp];
82bd882c8aSJames Wright     }
83bd882c8aSJames Wright   }
84bd882c8aSJames Wright }
85bd882c8aSJames Wright 
86bd882c8aSJames Wright //------------------------------------------------------------------------------
87bd882c8aSJames Wright // Single element -> E-vector
88bd882c8aSJames Wright //------------------------------------------------------------------------------
89bd882c8aSJames Wright inline void WriteElementStrided2d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
90bd882c8aSJames Wright                                   const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar *restrict r_v,
91bd882c8aSJames Wright                                   global CeedScalar *restrict d_v) {
92bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
93bd882c8aSJames Wright   const CeedInt item_id_y = get_local_id(1);
94bd882c8aSJames Wright   const CeedInt elem      = get_global_id(2);
95bd882c8aSJames Wright 
96bd882c8aSJames Wright   if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
97bd882c8aSJames Wright     const CeedInt node = item_id_x + item_id_y * P_1D;
98bd882c8aSJames Wright     const CeedInt ind  = node * strides_node + elem * strides_elem;
99bd882c8aSJames Wright     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
100bd882c8aSJames Wright       d_v[ind + comp * strides_comp] = r_v[comp];
101bd882c8aSJames Wright     }
102bd882c8aSJames Wright   }
103bd882c8aSJames Wright }
104bd882c8aSJames Wright 
105bd882c8aSJames Wright //------------------------------------------------------------------------------
106bd882c8aSJames Wright // 3D
107bd882c8aSJames Wright //------------------------------------------------------------------------------
108bd882c8aSJames Wright 
109bd882c8aSJames Wright //------------------------------------------------------------------------------
110bd882c8aSJames Wright // E-vector -> single element
111bd882c8aSJames Wright //------------------------------------------------------------------------------
112bd882c8aSJames Wright inline void ReadElementStrided3d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
113bd882c8aSJames Wright                                  const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar *restrict d_u,
114bd882c8aSJames Wright                                  private CeedScalar *restrict r_u) {
115bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
116bd882c8aSJames Wright   const CeedInt item_id_y = get_local_id(1);
117bd882c8aSJames Wright   const CeedInt elem      = get_global_id(2);
118bd882c8aSJames Wright 
119bd882c8aSJames Wright   if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
120bd882c8aSJames Wright     for (CeedInt z = 0; z < P_1D; z++) {
121bd882c8aSJames Wright       const CeedInt node = item_id_x + item_id_y * P_1D + z * P_1D * P_1D;
122bd882c8aSJames Wright       const CeedInt ind  = node * strides_node + elem * strides_elem;
123bd882c8aSJames Wright       for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
124bd882c8aSJames Wright         r_u[z + comp * P_1D] = d_u[ind + comp * strides_comp];
125bd882c8aSJames Wright       }
126bd882c8aSJames Wright     }
127bd882c8aSJames Wright   }
128bd882c8aSJames Wright }
129bd882c8aSJames Wright 
130bd882c8aSJames Wright //------------------------------------------------------------------------------
131bd882c8aSJames Wright // Single element -> E-vector
132bd882c8aSJames Wright //------------------------------------------------------------------------------
133bd882c8aSJames Wright inline void WriteElementStrided3d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
134bd882c8aSJames Wright                                   const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar *restrict r_v,
135bd882c8aSJames Wright                                   global CeedScalar *restrict d_v) {
136bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
137bd882c8aSJames Wright   const CeedInt item_id_y = get_local_id(1);
138bd882c8aSJames Wright   const CeedInt elem      = get_global_id(2);
139bd882c8aSJames Wright 
140bd882c8aSJames Wright   if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
141bd882c8aSJames Wright     for (CeedInt z = 0; z < P_1D; z++) {
142bd882c8aSJames Wright       const CeedInt node = item_id_x + item_id_y * P_1D + z * P_1D * P_1D;
143bd882c8aSJames Wright       const CeedInt ind  = node * strides_node + elem * strides_elem;
144bd882c8aSJames Wright       for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
145bd882c8aSJames Wright         d_v[ind + comp * strides_comp] = r_v[z + comp * P_1D];
146bd882c8aSJames Wright       }
147bd882c8aSJames Wright     }
148bd882c8aSJames Wright   }
149bd882c8aSJames Wright }
150