xref: /libCEED/include/ceed/jit-source/sycl/sycl-shared-basis-read-write-templates.h (revision edb2538e3dd6743c029967fc4e89c6fcafedb8c2)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 /// @file
9 /// Internal header for SYCL shared memory basis read/write templates
10 #ifndef CEED_SYCL_SHARED_BASIS_READ_WRITE_TEMPLATES_H
11 #define CEED_SYCL_SHARED_BASIS_READ_WRITE_TEMPLATES_H
12 
13 #include <ceed.h>
14 #include "sycl-types.h"
15 
16 //------------------------------------------------------------------------------
17 // Helper function: load matrices for basis actions
18 //------------------------------------------------------------------------------
19 inline void loadMatrix(const CeedInt N, const CeedScalar* restrict d_B, CeedScalar* restrict B) {
20   const CeedInt item_id    = get_local_linear_id();
21   const CeedInt group_size = get_local_size(0) * get_local_size(1) * get_local_size(2);
22   for (CeedInt i = item_id; i < N; i += group_size) B[i] = d_B[i];
23 }
24 
25 //------------------------------------------------------------------------------
26 // 1D
27 //------------------------------------------------------------------------------
28 
29 //------------------------------------------------------------------------------
30 // E-vector -> single element
31 //------------------------------------------------------------------------------
32 inline void ReadElementStrided1d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
33                                  const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar* restrict d_u,
34                                  private CeedScalar* restrict r_u) {
35   const CeedInt item_id_x = get_local_id(0);
36   const CeedInt elem      = get_global_id(2);
37 
38   if (item_id_x < P_1D && elem < num_elem) {
39     const CeedInt node = item_id_x;
40     const CeedInt ind  = node * strides_node + elem * strides_elem;
41     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
42       r_u[comp] = d_u[ind + comp * strides_comp];
43     }
44   }
45 }
46 
47 //------------------------------------------------------------------------------
48 // Single element -> E-vector
49 //------------------------------------------------------------------------------
50 inline void WriteElementStrided1d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
51                                   const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar* restrict r_v,
52                                   global CeedScalar* restrict d_v) {
53   const CeedInt item_id_x = get_local_id(0);
54   const CeedInt elem      = get_global_id(2);
55 
56   if (item_id_x < P_1D && elem < num_elem) {
57     const CeedInt node = item_id_x;
58     const CeedInt ind  = node * strides_node + elem * strides_elem;
59     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
60       d_v[ind + comp * strides_comp] = r_v[comp];
61     }
62   }
63 }
64 
65 //------------------------------------------------------------------------------
66 // 2D
67 //------------------------------------------------------------------------------
68 
69 //------------------------------------------------------------------------------
70 // E-vector -> single element
71 //------------------------------------------------------------------------------
72 inline void ReadElementStrided2d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
73                                  const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar* restrict d_u,
74                                  private CeedScalar* restrict r_u) {
75   const CeedInt item_id_x = get_local_id(0);
76   const CeedInt item_id_y = get_local_id(1);
77   const CeedInt elem      = get_global_id(2);
78 
79   if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
80     const CeedInt node = item_id_x + item_id_y * P_1D;
81     const CeedInt ind  = node * strides_node + elem * strides_elem;
82     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
83       r_u[comp] = d_u[ind + comp * strides_comp];
84     }
85   }
86 }
87 
88 //------------------------------------------------------------------------------
89 // Single element -> E-vector
90 //------------------------------------------------------------------------------
91 inline void WriteElementStrided2d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
92                                   const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar* restrict r_v,
93                                   global CeedScalar* restrict d_v) {
94   const CeedInt item_id_x = get_local_id(0);
95   const CeedInt item_id_y = get_local_id(1);
96   const CeedInt elem      = get_global_id(2);
97 
98   if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
99     const CeedInt node = item_id_x + item_id_y * P_1D;
100     const CeedInt ind  = node * strides_node + elem * strides_elem;
101     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
102       d_v[ind + comp * strides_comp] = r_v[comp];
103     }
104   }
105 }
106 
107 //------------------------------------------------------------------------------
108 // 3D
109 //------------------------------------------------------------------------------
110 
111 //------------------------------------------------------------------------------
112 // E-vector -> single element
113 //------------------------------------------------------------------------------
114 inline void ReadElementStrided3d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
115                                  const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar* restrict d_u,
116                                  private CeedScalar* restrict r_u) {
117   const CeedInt item_id_x = get_local_id(0);
118   const CeedInt item_id_y = get_local_id(1);
119   const CeedInt elem      = get_global_id(2);
120 
121   if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
122     for (CeedInt z = 0; z < P_1D; z++) {
123       const CeedInt node = item_id_x + item_id_y * P_1D + z * P_1D * P_1D;
124       const CeedInt ind  = node * strides_node + elem * strides_elem;
125       for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
126         r_u[z + comp * P_1D] = d_u[ind + comp * strides_comp];
127       }
128     }
129   }
130 }
131 
132 //------------------------------------------------------------------------------
133 // Single element -> E-vector
134 //------------------------------------------------------------------------------
135 inline void WriteElementStrided3d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
136                                   const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar* restrict r_v,
137                                   global CeedScalar* restrict d_v) {
138   const CeedInt item_id_x = get_local_id(0);
139   const CeedInt item_id_y = get_local_id(1);
140   const CeedInt elem      = get_global_id(2);
141 
142   if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
143     for (CeedInt z = 0; z < P_1D; z++) {
144       const CeedInt node = item_id_x + item_id_y * P_1D + z * P_1D * P_1D;
145       const CeedInt ind  = node * strides_node + elem * strides_elem;
146       for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
147         d_v[ind + comp * strides_comp] = r_v[z + comp * P_1D];
148       }
149     }
150   }
151 }
152 
153 //------------------------------------------------------------------------------
154 
155 #endif  // CEED_SYCL_SHARED_BASIS_READ_WRITE_TEMPLATES_H
156