xref: /libCEED/include/ceed/jit-source/sycl/sycl-shared-basis-read-write-templates.h (revision 5aed82e4fa97acf4ba24a7f10a35f5303a6798e0)
1*5aed82e4SJeremy L Thompson // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2bd882c8aSJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3bd882c8aSJames Wright //
4bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause
5bd882c8aSJames Wright //
6bd882c8aSJames Wright // This file is part of CEED:  http://github.com/ceed
7bd882c8aSJames Wright 
8bd882c8aSJames Wright /// @file
9bd882c8aSJames Wright /// Internal header for SYCL shared memory basis read/write templates
1094b7b29bSJeremy L Thompson #ifndef CEED_SYCL_SHARED_BASIS_READ_WRITE_TEMPLATES_H
1194b7b29bSJeremy L Thompson #define CEED_SYCL_SHARED_BASIS_READ_WRITE_TEMPLATES_H
12bd882c8aSJames Wright 
13bd882c8aSJames Wright #include <ceed.h>
14bd882c8aSJames Wright #include "sycl-types.h"
15bd882c8aSJames Wright 
16bd882c8aSJames Wright //------------------------------------------------------------------------------
17bd882c8aSJames Wright // Helper function: load matrices for basis actions
18bd882c8aSJames Wright //------------------------------------------------------------------------------
19bd882c8aSJames Wright inline void loadMatrix(const CeedInt N, const CeedScalar *restrict d_B, CeedScalar *restrict B) {
20bd882c8aSJames Wright   const CeedInt item_id    = get_local_linear_id();
21bd882c8aSJames Wright   const CeedInt group_size = get_local_size(0) * get_local_size(1) * get_local_size(2);
22bd882c8aSJames Wright   for (CeedInt i = item_id; i < N; i += group_size) B[i] = d_B[i];
23bd882c8aSJames Wright }
24bd882c8aSJames Wright 
25bd882c8aSJames Wright //------------------------------------------------------------------------------
26bd882c8aSJames Wright // 1D
27bd882c8aSJames Wright //------------------------------------------------------------------------------
28bd882c8aSJames Wright 
29bd882c8aSJames Wright //------------------------------------------------------------------------------
30bd882c8aSJames Wright // E-vector -> single element
31bd882c8aSJames Wright //------------------------------------------------------------------------------
32bd882c8aSJames Wright inline void ReadElementStrided1d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
33bd882c8aSJames Wright                                  const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar *restrict d_u,
34bd882c8aSJames Wright                                  private CeedScalar *restrict r_u) {
35bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
36bd882c8aSJames Wright   const CeedInt elem      = get_global_id(2);
37bd882c8aSJames Wright 
38bd882c8aSJames Wright   if (item_id_x < P_1D && elem < num_elem) {
39bd882c8aSJames Wright     const CeedInt node = item_id_x;
40bd882c8aSJames Wright     const CeedInt ind  = node * strides_node + elem * strides_elem;
41bd882c8aSJames Wright     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
42bd882c8aSJames Wright       r_u[comp] = d_u[ind + comp * strides_comp];
43bd882c8aSJames Wright     }
44bd882c8aSJames Wright   }
45bd882c8aSJames Wright }
46bd882c8aSJames Wright 
47bd882c8aSJames Wright //------------------------------------------------------------------------------
48bd882c8aSJames Wright // Single element -> E-vector
49bd882c8aSJames Wright //------------------------------------------------------------------------------
50bd882c8aSJames Wright inline void WriteElementStrided1d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
51bd882c8aSJames Wright                                   const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar *restrict r_v,
52bd882c8aSJames Wright                                   global CeedScalar *restrict d_v) {
53bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
54bd882c8aSJames Wright   const CeedInt elem      = get_global_id(2);
55bd882c8aSJames Wright 
56bd882c8aSJames Wright   if (item_id_x < P_1D && elem < num_elem) {
57bd882c8aSJames Wright     const CeedInt node = item_id_x;
58bd882c8aSJames Wright     const CeedInt ind  = node * strides_node + elem * strides_elem;
59bd882c8aSJames Wright     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
60bd882c8aSJames Wright       d_v[ind + comp * strides_comp] = r_v[comp];
61bd882c8aSJames Wright     }
62bd882c8aSJames Wright   }
63bd882c8aSJames Wright }
64bd882c8aSJames Wright 
65bd882c8aSJames Wright //------------------------------------------------------------------------------
66bd882c8aSJames Wright // 2D
67bd882c8aSJames Wright //------------------------------------------------------------------------------
68bd882c8aSJames Wright 
69bd882c8aSJames Wright //------------------------------------------------------------------------------
70bd882c8aSJames Wright // E-vector -> single element
71bd882c8aSJames Wright //------------------------------------------------------------------------------
72bd882c8aSJames Wright inline void ReadElementStrided2d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
73bd882c8aSJames Wright                                  const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar *restrict d_u,
74bd882c8aSJames Wright                                  private CeedScalar *restrict r_u) {
75bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
76bd882c8aSJames Wright   const CeedInt item_id_y = get_local_id(1);
77bd882c8aSJames Wright   const CeedInt elem      = get_global_id(2);
78bd882c8aSJames Wright 
79bd882c8aSJames Wright   if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
80bd882c8aSJames Wright     const CeedInt node = item_id_x + item_id_y * P_1D;
81bd882c8aSJames Wright     const CeedInt ind  = node * strides_node + elem * strides_elem;
82bd882c8aSJames Wright     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
83bd882c8aSJames Wright       r_u[comp] = d_u[ind + comp * strides_comp];
84bd882c8aSJames Wright     }
85bd882c8aSJames Wright   }
86bd882c8aSJames Wright }
87bd882c8aSJames Wright 
88bd882c8aSJames Wright //------------------------------------------------------------------------------
89bd882c8aSJames Wright // Single element -> E-vector
90bd882c8aSJames Wright //------------------------------------------------------------------------------
91bd882c8aSJames Wright inline void WriteElementStrided2d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
92bd882c8aSJames Wright                                   const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar *restrict r_v,
93bd882c8aSJames Wright                                   global CeedScalar *restrict d_v) {
94bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
95bd882c8aSJames Wright   const CeedInt item_id_y = get_local_id(1);
96bd882c8aSJames Wright   const CeedInt elem      = get_global_id(2);
97bd882c8aSJames Wright 
98bd882c8aSJames Wright   if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
99bd882c8aSJames Wright     const CeedInt node = item_id_x + item_id_y * P_1D;
100bd882c8aSJames Wright     const CeedInt ind  = node * strides_node + elem * strides_elem;
101bd882c8aSJames Wright     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
102bd882c8aSJames Wright       d_v[ind + comp * strides_comp] = r_v[comp];
103bd882c8aSJames Wright     }
104bd882c8aSJames Wright   }
105bd882c8aSJames Wright }
106bd882c8aSJames Wright 
107bd882c8aSJames Wright //------------------------------------------------------------------------------
108bd882c8aSJames Wright // 3D
109bd882c8aSJames Wright //------------------------------------------------------------------------------
110bd882c8aSJames Wright 
111bd882c8aSJames Wright //------------------------------------------------------------------------------
112bd882c8aSJames Wright // E-vector -> single element
113bd882c8aSJames Wright //------------------------------------------------------------------------------
114bd882c8aSJames Wright inline void ReadElementStrided3d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
115bd882c8aSJames Wright                                  const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar *restrict d_u,
116bd882c8aSJames Wright                                  private CeedScalar *restrict r_u) {
117bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
118bd882c8aSJames Wright   const CeedInt item_id_y = get_local_id(1);
119bd882c8aSJames Wright   const CeedInt elem      = get_global_id(2);
120bd882c8aSJames Wright 
121bd882c8aSJames Wright   if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
122bd882c8aSJames Wright     for (CeedInt z = 0; z < P_1D; z++) {
123bd882c8aSJames Wright       const CeedInt node = item_id_x + item_id_y * P_1D + z * P_1D * P_1D;
124bd882c8aSJames Wright       const CeedInt ind  = node * strides_node + elem * strides_elem;
125bd882c8aSJames Wright       for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
126bd882c8aSJames Wright         r_u[z + comp * P_1D] = d_u[ind + comp * strides_comp];
127bd882c8aSJames Wright       }
128bd882c8aSJames Wright     }
129bd882c8aSJames Wright   }
130bd882c8aSJames Wright }
131bd882c8aSJames Wright 
132bd882c8aSJames Wright //------------------------------------------------------------------------------
133bd882c8aSJames Wright // Single element -> E-vector
134bd882c8aSJames Wright //------------------------------------------------------------------------------
135bd882c8aSJames Wright inline void WriteElementStrided3d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
136bd882c8aSJames Wright                                   const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar *restrict r_v,
137bd882c8aSJames Wright                                   global CeedScalar *restrict d_v) {
138bd882c8aSJames Wright   const CeedInt item_id_x = get_local_id(0);
139bd882c8aSJames Wright   const CeedInt item_id_y = get_local_id(1);
140bd882c8aSJames Wright   const CeedInt elem      = get_global_id(2);
141bd882c8aSJames Wright 
142bd882c8aSJames Wright   if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
143bd882c8aSJames Wright     for (CeedInt z = 0; z < P_1D; z++) {
144bd882c8aSJames Wright       const CeedInt node = item_id_x + item_id_y * P_1D + z * P_1D * P_1D;
145bd882c8aSJames Wright       const CeedInt ind  = node * strides_node + elem * strides_elem;
146bd882c8aSJames Wright       for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
147bd882c8aSJames Wright         d_v[ind + comp * strides_comp] = r_v[z + comp * P_1D];
148bd882c8aSJames Wright       }
149bd882c8aSJames Wright     }
150bd882c8aSJames Wright   }
151bd882c8aSJames Wright }
152bd882c8aSJames Wright 
153bd882c8aSJames Wright //------------------------------------------------------------------------------
154bd882c8aSJames Wright 
15594b7b29bSJeremy L Thompson #endif  // CEED_SYCL_SHARED_BASIS_READ_WRITE_TEMPLATES_H
156