xref: /libCEED/include/ceed/jit-source/hip/hip-shared-basis-read-write-templates.h (revision 672b0f2ac2d233f11bebf0085c50d29e53ac87eb)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 /// @file
9 /// Internal header for HIP shared memory basis read/write templates
10 #ifndef CEED_HIP_SHARED_BASIS_READ_WRITE_TEMPLATES_H
11 #define CEED_HIP_SHARED_BASIS_READ_WRITE_TEMPLATES_H
12 
13 #include <ceed.h>
14 
//------------------------------------------------------------------------------
// Helper function: load matrices for basis actions
//
// Copies the first SIZE entries of d_B into B, splitting the work across every
// thread of the block. Each thread starts at its flattened (x, y, z) thread
// index and advances by the block's total thread count
// (blockDim.x * blockDim.y * blockDim.z).
//
// No barrier is issued here. NOTE(review): presumably d_B is device memory and B
// a block-shared buffer, so callers must synchronize before reading B — confirm
// at call sites.
//------------------------------------------------------------------------------
template <int SIZE>
inline __device__ void loadMatrix(const CeedScalar *d_B, CeedScalar *B) {
  const unsigned int n_threads = blockDim.x * blockDim.y * blockDim.z;
  CeedInt            idx       = (threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x;

  while (idx < SIZE) {
    B[idx] = d_B[idx];
    idx += n_threads;
  }
}
24 
25 //------------------------------------------------------------------------------
26 // 1D
27 //------------------------------------------------------------------------------
28 
//------------------------------------------------------------------------------
// E-vector -> single element (1D)
//
// Gathers the NUM_COMP values of node data.t_id_x of element `elem` from the
// strided E-vector d_u into r_u[0 .. NUM_COMP). The value for (node, comp, elem)
// is read from offset node * strides_node + comp * strides_comp + elem * strides_elem.
// Threads with t_id_x >= P_1D own no node and leave r_u untouched.
//
// NOTE(review): offsets are formed in CeedInt (presumably 32-bit); for very large
// E-vectors elem * strides_elem or comp * strides_comp could overflow — confirm sizes.
//------------------------------------------------------------------------------
template <int NUM_COMP, int P_1D>
inline __device__ void ReadElementStrided1d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                            const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
  // Only the first P_1D threads along x map to a node.
  if (data.t_id_x >= P_1D) return;

  const CeedInt base = data.t_id_x * strides_node + elem * strides_elem;

  for (CeedInt c = 0; c < NUM_COMP; c++) r_u[c] = d_u[base + c * strides_comp];
}
44 
//------------------------------------------------------------------------------
// Single element -> E-vector (1D)
//
// Scatters the NUM_COMP values in r_v[0 .. NUM_COMP) for node data.t_id_x of
// element `elem` into the strided E-vector d_v, at offset
// node * strides_node + comp * strides_comp + elem * strides_elem.
// Threads with t_id_x >= P_1D own no node and write nothing.
//
// NOTE(review): offsets are formed in CeedInt (presumably 32-bit); for very large
// E-vectors elem * strides_elem or comp * strides_comp could overflow — confirm sizes.
//------------------------------------------------------------------------------
template <int NUM_COMP, int P_1D>
inline __device__ void WriteElementStrided1d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                             const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
  // Only the first P_1D threads along x map to a node.
  if (data.t_id_x >= P_1D) return;

  const CeedInt base = data.t_id_x * strides_node + elem * strides_elem;

  for (CeedInt c = 0; c < NUM_COMP; c++) d_v[base + c * strides_comp] = r_v[c];
}
60 
61 //------------------------------------------------------------------------------
62 // 2D
63 //------------------------------------------------------------------------------
64 
//------------------------------------------------------------------------------
// E-vector -> single element (2D)
//
// Each thread with (t_id_x, t_id_y) inside the P_1D x P_1D node grid owns the
// node t_id_x + t_id_y * P_1D of element `elem`. It gathers that node's NUM_COMP
// values from the strided E-vector d_u into r_u[0 .. NUM_COMP).
// Threads outside the grid leave r_u untouched.
//------------------------------------------------------------------------------
template <int NUM_COMP, int P_1D>
inline __device__ void ReadElementStrided2d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                            const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
  if (data.t_id_x >= P_1D || data.t_id_y >= P_1D) return;

  // Lexicographic node index with x varying fastest.
  const CeedInt node = data.t_id_y * P_1D + data.t_id_x;
  const CeedInt base = node * strides_node + elem * strides_elem;

  for (CeedInt c = 0; c < NUM_COMP; c++) r_u[c] = d_u[base + c * strides_comp];
}
80 
//------------------------------------------------------------------------------
// Single element -> E-vector (2D)
//
// Each thread with (t_id_x, t_id_y) inside the P_1D x P_1D node grid writes the
// NUM_COMP values in r_v[0 .. NUM_COMP) for node t_id_x + t_id_y * P_1D of
// element `elem` into the strided E-vector d_v.
// Threads outside the grid write nothing.
//------------------------------------------------------------------------------
template <int NUM_COMP, int P_1D>
inline __device__ void WriteElementStrided2d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                             const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
  if (data.t_id_x >= P_1D || data.t_id_y >= P_1D) return;

  // Lexicographic node index with x varying fastest.
  const CeedInt node = data.t_id_y * P_1D + data.t_id_x;
  const CeedInt base = node * strides_node + elem * strides_elem;

  for (CeedInt c = 0; c < NUM_COMP; c++) d_v[base + c * strides_comp] = r_v[c];
}
96 
97 //------------------------------------------------------------------------------
98 // 3D
99 //------------------------------------------------------------------------------
100 
//------------------------------------------------------------------------------
// E-vector -> single element (3D)
//
// Each thread with (t_id_x, t_id_y) inside the P_1D x P_1D face loads the whole
// z-column of P_1D nodes of element `elem` from the strided E-vector d_u.
// Node (x, y, z) has index x + y * P_1D + z * P_1D * P_1D.
// r_u is stored component-major, r_u[z + comp * P_1D], so it must hold
// NUM_COMP * P_1D entries. Threads outside the face leave r_u untouched.
//------------------------------------------------------------------------------
template <int NUM_COMP, int P_1D>
inline __device__ void ReadElementStrided3d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                            const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
  if (data.t_id_x >= P_1D || data.t_id_y >= P_1D) return;

  const CeedInt plane   = P_1D * P_1D;
  const CeedInt node_xy = data.t_id_y * P_1D + data.t_id_x;  // node index in the z = 0 plane

  for (CeedInt k = 0; k < P_1D; k++) {
    const CeedInt base = (node_xy + k * plane) * strides_node + elem * strides_elem;

    for (CeedInt c = 0; c < NUM_COMP; c++) r_u[c * P_1D + k] = d_u[base + c * strides_comp];
  }
}
118 
//------------------------------------------------------------------------------
// Single element -> E-vector (3D)
//
// Each thread with (t_id_x, t_id_y) inside the P_1D x P_1D face writes the whole
// z-column of P_1D nodes of element `elem` into the strided E-vector d_v.
// Node (x, y, z) has index x + y * P_1D + z * P_1D * P_1D.
// r_v is read component-major, r_v[z + comp * P_1D].
// Threads outside the face write nothing.
//------------------------------------------------------------------------------
template <int NUM_COMP, int P_1D>
inline __device__ void WriteElementStrided3d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                             const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
  if (data.t_id_x >= P_1D || data.t_id_y >= P_1D) return;

  const CeedInt plane   = P_1D * P_1D;
  const CeedInt node_xy = data.t_id_y * P_1D + data.t_id_x;  // node index in the z = 0 plane

  for (CeedInt k = 0; k < P_1D; k++) {
    const CeedInt base = (node_xy + k * plane) * strides_node + elem * strides_elem;

    for (CeedInt c = 0; c < NUM_COMP; c++) d_v[base + c * strides_comp] = r_v[c * P_1D + k];
  }
}
136 
137 //------------------------------------------------------------------------------
138 
139 #endif  // CEED_HIP_SHARED_BASIS_READ_WRITE_TEMPLATES_H
140