xref: /libCEED/include/ceed/jit-source/cuda/cuda-shared-basis-read-write-templates.h (revision fc0f7cc68128f3536a834f19d72828f4c59a4439)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 /// @file
9 /// Internal header for CUDA shared memory basis read/write templates
10 
11 #include <ceed.h>
12 
13 //------------------------------------------------------------------------------
14 // 1D
15 //------------------------------------------------------------------------------
16 
17 //------------------------------------------------------------------------------
18 // E-vector -> single element
19 //------------------------------------------------------------------------------
20 template <int NUM_COMP, int P_1D>
21 inline __device__ void ReadElementStrided1d(SharedData_Cuda &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
22                                             const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
23   if (data.t_id_x < P_1D) {
24     const CeedInt node = data.t_id_x;
25     const CeedInt ind  = node * strides_node + elem * strides_elem;
26 
27     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
28       r_u[comp] = d_u[ind + comp * strides_comp];
29     }
30   }
31 }
32 
33 //------------------------------------------------------------------------------
34 // Single element -> E-vector
35 //------------------------------------------------------------------------------
36 template <int NUM_COMP, int P_1D>
37 inline __device__ void WriteElementStrided1d(SharedData_Cuda &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
38                                              const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
39   if (data.t_id_x < P_1D) {
40     const CeedInt node = data.t_id_x;
41     const CeedInt ind  = node * strides_node + elem * strides_elem;
42 
43     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
44       d_v[ind + comp * strides_comp] = r_v[comp];
45     }
46   }
47 }
48 
49 //------------------------------------------------------------------------------
50 // 2D
51 //------------------------------------------------------------------------------
52 
53 //------------------------------------------------------------------------------
54 // E-vector -> single element
55 //------------------------------------------------------------------------------
56 template <int NUM_COMP, int P_1D>
57 inline __device__ void ReadElementStrided2d(SharedData_Cuda &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
58                                             const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
59   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
60     const CeedInt node = data.t_id_x + data.t_id_y * P_1D;
61     const CeedInt ind  = node * strides_node + elem * strides_elem;
62 
63     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
64       r_u[comp] = d_u[ind + comp * strides_comp];
65     }
66   }
67 }
68 
69 //------------------------------------------------------------------------------
70 // Single element -> E-vector
71 //------------------------------------------------------------------------------
72 template <int NUM_COMP, int P_1D>
73 inline __device__ void WriteElementStrided2d(SharedData_Cuda &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
74                                              const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
75   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
76     const CeedInt node = data.t_id_x + data.t_id_y * P_1D;
77     const CeedInt ind  = node * strides_node + elem * strides_elem;
78 
79     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
80       d_v[ind + comp * strides_comp] = r_v[comp];
81     }
82   }
83 }
84 
85 //------------------------------------------------------------------------------
86 // 3D
87 //------------------------------------------------------------------------------
88 
89 //------------------------------------------------------------------------------
90 // E-vector -> single element
91 //------------------------------------------------------------------------------
92 template <int NUM_COMP, int P_1D>
93 inline __device__ void ReadElementStrided3d(SharedData_Cuda &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
94                                             const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
95   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
96     for (CeedInt z = 0; z < P_1D; z++) {
97       const CeedInt node = data.t_id_x + data.t_id_y * P_1D + z * P_1D * P_1D;
98       const CeedInt ind  = node * strides_node + elem * strides_elem;
99 
100       for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
101         r_u[z + comp * P_1D] = d_u[ind + comp * strides_comp];
102       }
103     }
104   }
105 }
106 
107 //------------------------------------------------------------------------------
108 // Single element -> E-vector
109 //------------------------------------------------------------------------------
110 template <int NUM_COMP, int P_1D>
111 inline __device__ void WriteElementStrided3d(SharedData_Cuda &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
112                                              const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
113   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
114     for (CeedInt z = 0; z < P_1D; z++) {
115       const CeedInt node = data.t_id_x + data.t_id_y * P_1D + z * P_1D * P_1D;
116       const CeedInt ind  = node * strides_node + elem * strides_elem;
117 
118       for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
119         d_v[ind + comp * strides_comp] = r_v[z + comp * P_1D];
120       }
121     }
122   }
123 }
124