xref: /libCEED/include/ceed/jit-source/cuda/cuda-shared-basis-read-write-templates.h (revision 4fee36f0a30516a0b5ad51bf7eb3b32d83efd623)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 /// @file
9 /// Internal header for CUDA shared memory basis read/write templates
10 #ifndef _ceed_cuda_shared_basis_read_write_templates_h
11 #define _ceed_cuda_shared_basis_read_write_templates_h
12 
13 #include <ceed.h>
14 
15 //------------------------------------------------------------------------------
16 // 1D
17 //------------------------------------------------------------------------------
18 
19 //------------------------------------------------------------------------------
20 // E-vector -> single element
21 //------------------------------------------------------------------------------
// Load the NUM_COMP values of one node of element `elem` from the strided E-vector d_u into r_u.
// Thread t_id_x handles node t_id_x; threads with t_id_x >= P_1D do nothing.
//   r_u[comp] = d_u[node * strides_node + comp * strides_comp + elem * strides_elem]
template <int NUM_COMP, int P_1D>
inline __device__ void ReadElementStrided1d(SharedData_Cuda &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                            const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
  // Threads past the last node have nothing to read
  if (data.t_id_x >= P_1D) return;

  // Offset of component 0 for this thread's node within the element
  const CeedInt i_node = data.t_id_x;
  const CeedInt offset = i_node * strides_node + elem * strides_elem;

  for (CeedInt c = 0; c < NUM_COMP; ++c) {
    r_u[c] = d_u[offset + c * strides_comp];
  }
}
33 
34 //------------------------------------------------------------------------------
35 // Single element -> E-vector
36 //------------------------------------------------------------------------------
// Store the NUM_COMP values held in r_v for one node of element `elem` into the strided E-vector d_v.
// Thread t_id_x handles node t_id_x; threads with t_id_x >= P_1D do nothing.
//   d_v[node * strides_node + comp * strides_comp + elem * strides_elem] = r_v[comp]
template <int NUM_COMP, int P_1D>
inline __device__ void WriteElementStrided1d(SharedData_Cuda &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                             const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
  // Threads past the last node have nothing to write
  if (data.t_id_x >= P_1D) return;

  // Offset of component 0 for this thread's node within the element
  const CeedInt i_node = data.t_id_x;
  const CeedInt offset = i_node * strides_node + elem * strides_elem;

  for (CeedInt c = 0; c < NUM_COMP; ++c) {
    d_v[offset + c * strides_comp] = r_v[c];
  }
}
48 
49 //------------------------------------------------------------------------------
50 // 2D
51 //------------------------------------------------------------------------------
52 
53 //------------------------------------------------------------------------------
54 // E-vector -> single element
55 //------------------------------------------------------------------------------
// Load the NUM_COMP values of one node of element `elem` from the strided E-vector d_u into r_u.
// Thread (t_id_x, t_id_y) handles node t_id_x + t_id_y * P_1D; threads outside the P_1D x P_1D
// range do nothing.
template <int NUM_COMP, int P_1D>
inline __device__ void ReadElementStrided2d(SharedData_Cuda &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                            const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
  // Skip threads that do not map to a node
  if (data.t_id_x >= P_1D || data.t_id_y >= P_1D) return;

  // Flattened node index, x fastest
  const CeedInt i_node = data.t_id_x + data.t_id_y * P_1D;
  // Offset of component 0 for this node within the element
  const CeedInt offset = i_node * strides_node + elem * strides_elem;

  for (CeedInt c = 0; c < NUM_COMP; ++c) {
    r_u[c] = d_u[offset + c * strides_comp];
  }
}
67 
68 //------------------------------------------------------------------------------
69 // Single element -> E-vector
70 //------------------------------------------------------------------------------
// Store the NUM_COMP values held in r_v for one node of element `elem` into the strided E-vector d_v.
// Thread (t_id_x, t_id_y) handles node t_id_x + t_id_y * P_1D; threads outside the P_1D x P_1D
// range do nothing.
template <int NUM_COMP, int P_1D>
inline __device__ void WriteElementStrided2d(SharedData_Cuda &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                             const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
  // Skip threads that do not map to a node
  if (data.t_id_x >= P_1D || data.t_id_y >= P_1D) return;

  // Flattened node index, x fastest
  const CeedInt i_node = data.t_id_x + data.t_id_y * P_1D;
  // Offset of component 0 for this node within the element
  const CeedInt offset = i_node * strides_node + elem * strides_elem;

  for (CeedInt c = 0; c < NUM_COMP; ++c) {
    d_v[offset + c * strides_comp] = r_v[c];
  }
}
82 
83 //------------------------------------------------------------------------------
84 // 3D
85 //------------------------------------------------------------------------------
86 
87 //------------------------------------------------------------------------------
88 // E-vector -> single element
89 //------------------------------------------------------------------------------
// Load one z-column of element `elem` from the strided E-vector d_u into r_u.
// Thread (t_id_x, t_id_y) walks all P_1D nodes sharing its (x, y) position; threads outside the
// P_1D x P_1D range do nothing. Values are stored as r_u[z + comp * P_1D], so r_u must hold
// NUM_COMP * P_1D entries.
template <int NUM_COMP, int P_1D>
inline __device__ void ReadElementStrided3d(SharedData_Cuda &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                            const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
  // Skip threads that do not map to an (x, y) node position
  if (data.t_id_x >= P_1D || data.t_id_y >= P_1D) return;

  // Node index of this thread within the z = 0 plane, x fastest
  const CeedInt i_plane = data.t_id_x + data.t_id_y * P_1D;

  for (CeedInt k = 0; k < P_1D; ++k) {
    // Each z-plane holds P_1D * P_1D nodes
    const CeedInt i_node = i_plane + k * P_1D * P_1D;
    const CeedInt offset = i_node * strides_node + elem * strides_elem;

    for (CeedInt c = 0; c < NUM_COMP; ++c) {
      r_u[k + c * P_1D] = d_u[offset + c * strides_comp];
    }
  }
}
103 
104 //------------------------------------------------------------------------------
105 // Single element -> E-vector
106 //------------------------------------------------------------------------------
// Store one z-column of element `elem` from r_v into the strided E-vector d_v.
// Thread (t_id_x, t_id_y) walks all P_1D nodes sharing its (x, y) position; threads outside the
// P_1D x P_1D range do nothing. Values are read as r_v[z + comp * P_1D], so r_v must hold
// NUM_COMP * P_1D entries.
template <int NUM_COMP, int P_1D>
inline __device__ void WriteElementStrided3d(SharedData_Cuda &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                             const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
  // Skip threads that do not map to an (x, y) node position
  if (data.t_id_x >= P_1D || data.t_id_y >= P_1D) return;

  // Node index of this thread within the z = 0 plane, x fastest
  const CeedInt i_plane = data.t_id_x + data.t_id_y * P_1D;

  for (CeedInt k = 0; k < P_1D; ++k) {
    // Each z-plane holds P_1D * P_1D nodes
    const CeedInt i_node = i_plane + k * P_1D * P_1D;
    const CeedInt offset = i_node * strides_node + elem * strides_elem;

    for (CeedInt c = 0; c < NUM_COMP; ++c) {
      d_v[offset + c * strides_comp] = r_v[k + c * P_1D];
    }
  }
}
120 
121 //------------------------------------------------------------------------------
122 
123 #endif
124