xref: /libCEED/include/ceed/jit-source/hip/hip-shared-basis-read-write-templates.h (revision 9dc0ea9a12d5a2dbb50983bee29c25b398979cc0)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 /// @file
9 /// Internal header for HIP shared memory basis read/write templates
10 
11 #include <ceed.h>
12 
13 //------------------------------------------------------------------------------
14 // Helper function: load matrices for basis actions
15 //------------------------------------------------------------------------------
template <int SIZE>
inline __device__ void loadMatrix(const CeedScalar *d_B, CeedScalar *B) {
  // Copy SIZE entries of d_B into B, with every thread of the block taking a
  // strided share: thread `rank` copies entries rank, rank + n_threads, ...
  // NOTE(review): B is presumably a shared-memory buffer; confirm at the call site.
  // No barrier is issued here. If other threads later read entries of B, a
  // block-level barrier is required first.
  const CeedInt n_threads = blockDim.x * blockDim.y * blockDim.z;
  const CeedInt rank      = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);

  for (CeedInt k = rank; k < SIZE; k += n_threads) {
    B[k] = d_B[k];
  }
}
22 
23 //------------------------------------------------------------------------------
24 // 1D
25 //------------------------------------------------------------------------------
26 
27 //------------------------------------------------------------------------------
28 // E-vector -> single element
29 //------------------------------------------------------------------------------
template <int NUM_COMP, int P_1D>
inline __device__ void ReadElementStrided1d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                            const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
  // Gather the NUM_COMP values of node t_id_x of element `elem` from the
  // strided E-vector d_u into r_u[0..NUM_COMP-1]. The value for (node, comp)
  // sits at node * strides_node + comp * strides_comp + elem * strides_elem.
  // Threads with t_id_x >= P_1D own no node and leave r_u untouched.
  if (!(data.t_id_x < P_1D)) return;

  const CeedInt node   = data.t_id_x;
  const CeedInt offset = elem * strides_elem + node * strides_node;

  for (CeedInt c = 0; c < NUM_COMP; c++) r_u[c] = d_u[offset + c * strides_comp];
}
42 
43 //------------------------------------------------------------------------------
44 // Single element -> E-vector
45 //------------------------------------------------------------------------------
template <int NUM_COMP, int P_1D>
inline __device__ void WriteElementStrided1d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                             const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
  // Scatter r_v[0..NUM_COMP-1] into the strided E-vector d_v at node t_id_x
  // of element `elem`. Component comp goes to
  // node * strides_node + comp * strides_comp + elem * strides_elem.
  // Threads with t_id_x >= P_1D own no node and write nothing.
  if (!(data.t_id_x < P_1D)) return;

  const CeedInt node   = data.t_id_x;
  const CeedInt offset = elem * strides_elem + node * strides_node;

  for (CeedInt c = 0; c < NUM_COMP; c++) d_v[offset + c * strides_comp] = r_v[c];
}
58 
59 //------------------------------------------------------------------------------
60 // 2D
61 //------------------------------------------------------------------------------
62 
63 //------------------------------------------------------------------------------
64 // E-vector -> single element
65 //------------------------------------------------------------------------------
template <int NUM_COMP, int P_1D>
inline __device__ void ReadElementStrided2d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                            const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
  // Gather the NUM_COMP values of the node at (t_id_x, t_id_y) of element
  // `elem` from the strided E-vector d_u into r_u[0..NUM_COMP-1]. Nodes are
  // numbered x-fastest: node = t_id_x + t_id_y * P_1D.
  // Threads outside the P_1D x P_1D node tile leave r_u untouched.
  const bool owns_node = data.t_id_x < P_1D && data.t_id_y < P_1D;
  if (!owns_node) return;

  const CeedInt node   = data.t_id_y * P_1D + data.t_id_x;
  const CeedInt offset = elem * strides_elem + node * strides_node;

  for (CeedInt c = 0; c < NUM_COMP; c++) r_u[c] = d_u[offset + c * strides_comp];
}
78 
79 //------------------------------------------------------------------------------
80 // Single element -> E-vector
81 //------------------------------------------------------------------------------
template <int NUM_COMP, int P_1D>
inline __device__ void WriteElementStrided2d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                             const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
  // Scatter r_v[0..NUM_COMP-1] into the strided E-vector d_v at the node
  // (t_id_x, t_id_y) of element `elem`, with node = t_id_x + t_id_y * P_1D.
  // Threads outside the P_1D x P_1D node tile write nothing.
  const bool owns_node = data.t_id_x < P_1D && data.t_id_y < P_1D;
  if (!owns_node) return;

  const CeedInt node   = data.t_id_y * P_1D + data.t_id_x;
  const CeedInt offset = elem * strides_elem + node * strides_node;

  for (CeedInt c = 0; c < NUM_COMP; c++) d_v[offset + c * strides_comp] = r_v[c];
}
94 
95 //------------------------------------------------------------------------------
96 // 3D
97 //------------------------------------------------------------------------------
98 
99 //------------------------------------------------------------------------------
100 // E-vector -> single element
101 //------------------------------------------------------------------------------
template <int NUM_COMP, int P_1D>
inline __device__ void ReadElementStrided3d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                            const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
  // Each thread at (t_id_x, t_id_y) owns a column of P_1D nodes along z and
  // gathers all of them from the strided E-vector d_u. The value for
  // (z, comp) is stored in r_u[z + comp * P_1D], so r_u must hold at least
  // NUM_COMP * P_1D entries. Nodes are numbered
  // x + y * P_1D + z * P_1D * P_1D. Threads outside the P_1D x P_1D tile
  // leave r_u untouched.
  if (data.t_id_x >= P_1D || data.t_id_y >= P_1D) return;

  // Node number of this thread's column at z = 0
  const CeedInt column = data.t_id_x + data.t_id_y * P_1D;

  for (CeedInt z = 0; z < P_1D; z++) {
    const CeedInt offset = (column + z * P_1D * P_1D) * strides_node + elem * strides_elem;

    for (CeedInt c = 0; c < NUM_COMP; c++) r_u[c * P_1D + z] = d_u[offset + c * strides_comp];
  }
}
116 
117 //------------------------------------------------------------------------------
118 // Single element -> E-vector
119 //------------------------------------------------------------------------------
template <int NUM_COMP, int P_1D>
inline __device__ void WriteElementStrided3d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                             const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
  // Each thread at (t_id_x, t_id_y) scatters its column of P_1D nodes along z
  // into the strided E-vector d_v. The value for (z, comp) is read from
  // r_v[z + comp * P_1D]. Nodes are numbered
  // x + y * P_1D + z * P_1D * P_1D. Threads outside the P_1D x P_1D tile
  // write nothing.
  if (data.t_id_x >= P_1D || data.t_id_y >= P_1D) return;

  // Node number of this thread's column at z = 0
  const CeedInt column = data.t_id_x + data.t_id_y * P_1D;

  for (CeedInt z = 0; z < P_1D; z++) {
    const CeedInt offset = (column + z * P_1D * P_1D) * strides_node + elem * strides_elem;

    for (CeedInt c = 0; c < NUM_COMP; c++) d_v[offset + c * strides_comp] = r_v[c * P_1D + z];
  }
}
134