xref: /libCEED/include/ceed/jit-source/hip/hip-shared-basis-read-write-templates.h (revision daadeac6547c0bce0e170b8a41c931051f52e9a3)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 /// @file
9 /// Internal header for HIP shared memory basis read/write templates
10 #ifndef _ceed_hip_shared_basis_read_write_templates_h
11 #define _ceed_hip_shared_basis_read_write_templates_h
12 
13 #include <ceed.h>
14 
15 //------------------------------------------------------------------------------
16 // Helper function: load matrices for basis actions
17 //------------------------------------------------------------------------------
18 template <int SIZE>
19 inline __device__ void loadMatrix(const CeedScalar *d_B, CeedScalar *B) {
20   CeedInt tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.y * blockDim.x;
21   for (CeedInt i = tid; i < SIZE; i += blockDim.x * blockDim.y * blockDim.z) B[i] = d_B[i];
22 }
23 
24 //------------------------------------------------------------------------------
25 // 1D
26 //------------------------------------------------------------------------------
27 
28 //------------------------------------------------------------------------------
29 // E-vector -> single element
30 //------------------------------------------------------------------------------
31 template <int NUM_COMP, int P_1D>
32 inline __device__ void ReadElementStrided1d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
33                                             const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
34   if (data.t_id_x < P_1D) {
35     const CeedInt node = data.t_id_x;
36     const CeedInt ind  = node * strides_node + elem * strides_elem;
37     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
38       r_u[comp] = d_u[ind + comp * strides_comp];
39     }
40   }
41 }
42 
43 //------------------------------------------------------------------------------
44 // Single element -> E-vector
45 //------------------------------------------------------------------------------
46 template <int NUM_COMP, int P_1D>
47 inline __device__ void WriteElementStrided1d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
48                                              const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
49   if (data.t_id_x < P_1D) {
50     const CeedInt node = data.t_id_x;
51     const CeedInt ind  = node * strides_node + elem * strides_elem;
52     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
53       d_v[ind + comp * strides_comp] = r_v[comp];
54     }
55   }
56 }
57 
58 //------------------------------------------------------------------------------
59 // 2D
60 //------------------------------------------------------------------------------
61 
62 //------------------------------------------------------------------------------
63 // E-vector -> single element
64 //------------------------------------------------------------------------------
65 template <int NUM_COMP, int P_1D>
66 inline __device__ void ReadElementStrided2d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
67                                             const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
68   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
69     const CeedInt node = data.t_id_x + data.t_id_y * P_1D;
70     const CeedInt ind  = node * strides_node + elem * strides_elem;
71     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
72       r_u[comp] = d_u[ind + comp * strides_comp];
73     }
74   }
75 }
76 
77 //------------------------------------------------------------------------------
78 // Single element -> E-vector
79 //------------------------------------------------------------------------------
80 template <int NUM_COMP, int P_1D>
81 inline __device__ void WriteElementStrided2d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
82                                              const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
83   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
84     const CeedInt node = data.t_id_x + data.t_id_y * P_1D;
85     const CeedInt ind  = node * strides_node + elem * strides_elem;
86     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
87       d_v[ind + comp * strides_comp] = r_v[comp];
88     }
89   }
90 }
91 
92 //------------------------------------------------------------------------------
93 // 3D
94 //------------------------------------------------------------------------------
95 
96 //------------------------------------------------------------------------------
97 // E-vector -> single element
98 //------------------------------------------------------------------------------
99 template <int NUM_COMP, int P_1D>
100 inline __device__ void ReadElementStrided3d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
101                                             const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
102   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
103     for (CeedInt z = 0; z < P_1D; z++) {
104       const CeedInt node = data.t_id_x + data.t_id_y * P_1D + z * P_1D * P_1D;
105       const CeedInt ind  = node * strides_node + elem * strides_elem;
106       for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
107         r_u[z + comp * P_1D] = d_u[ind + comp * strides_comp];
108       }
109     }
110   }
111 }
112 
113 //------------------------------------------------------------------------------
114 // Single element -> E-vector
115 //------------------------------------------------------------------------------
116 template <int NUM_COMP, int P_1D>
117 inline __device__ void WriteElementStrided3d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
118                                              const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
119   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
120     for (CeedInt z = 0; z < P_1D; z++) {
121       const CeedInt node = data.t_id_x + data.t_id_y * P_1D + z * P_1D * P_1D;
122       const CeedInt ind  = node * strides_node + elem * strides_elem;
123       for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
124         d_v[ind + comp * strides_comp] = r_v[z + comp * P_1D];
125       }
126     }
127   }
128 }
129 
130 //------------------------------------------------------------------------------
131 
132 #endif
133