xref: /libCEED/include/ceed/jit-source/hip/hip-shared-basis-read-write-templates.h (revision 947f93aa7135eb1759bf2866bd2fbd481436b113)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 /// @file
9 /// Internal header for HIP shared memory basis read/write templates
10 #ifndef _ceed_hip_shared_basis_read_write_templates_h
11 #define _ceed_hip_shared_basis_read_write_templates_h
12 
13 #include <ceed.h>
14 
15 
16 //------------------------------------------------------------------------------
17 // Helper function: load matrices for basis actions
18 //------------------------------------------------------------------------------
19 template <int SIZE>
20 inline __device__ void loadMatrix(const CeedScalar* d_B, CeedScalar* B) {
21   CeedInt tid = threadIdx.x + threadIdx.y*blockDim.x + threadIdx.z*blockDim.y*blockDim.x;
22   for (CeedInt i = tid; i < SIZE; i += blockDim.x*blockDim.y*blockDim.z)
23     B[i] = d_B[i];
24 }
25 
26 //------------------------------------------------------------------------------
27 // 1D
28 //------------------------------------------------------------------------------
29 
30 //------------------------------------------------------------------------------
31 // E-vector -> single element
32 //------------------------------------------------------------------------------
33 template <int NUM_COMP, int P_1D>
34 inline __device__ void ReadElementStrided1d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp, const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
35   if (data.t_id_x < P_1D) {
36     const CeedInt node = data.t_id_x;
37     const CeedInt ind = node * strides_node + elem * strides_elem;
38     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
39       r_u[comp] = d_u[ind + comp * strides_comp];
40     }
41   }
42 }
43 
44 //------------------------------------------------------------------------------
45 // Single element -> E-vector
46 //------------------------------------------------------------------------------
47 template <int NUM_COMP, int P_1D>
48 inline __device__ void WriteElementStrided1d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp, const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
49   if (data.t_id_x < P_1D) {
50     const CeedInt node = data.t_id_x;
51     const CeedInt ind = node * strides_node + elem * strides_elem;
52     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
53       d_v[ind + comp * strides_comp] = r_v[comp];
54     }
55   }
56 }
57 
58 //------------------------------------------------------------------------------
59 // 2D
60 //------------------------------------------------------------------------------
61 
62 //------------------------------------------------------------------------------
63 // E-vector -> single element
64 //------------------------------------------------------------------------------
65 template <int NUM_COMP, int P_1D>
66 inline __device__ void ReadElementStrided2d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp, const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
67   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
68     const CeedInt node = data.t_id_x + data.t_id_y*P_1D;
69     const CeedInt ind = node * strides_node + elem * strides_elem;
70     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
71       r_u[comp] = d_u[ind + comp * strides_comp];
72     }
73   }
74 }
75 
76 //------------------------------------------------------------------------------
77 // Single element -> E-vector
78 //------------------------------------------------------------------------------
79 template <int NUM_COMP, int P_1D>
80 inline __device__ void WriteElementStrided2d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp, const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
81   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
82     const CeedInt node = data.t_id_x + data.t_id_y*P_1D;
83     const CeedInt ind = node * strides_node + elem * strides_elem;
84     for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
85       d_v[ind + comp * strides_comp] = r_v[comp];
86     }
87   }
88 }
89 
90 //------------------------------------------------------------------------------
91 // 3D
92 //------------------------------------------------------------------------------
93 
94 //------------------------------------------------------------------------------
95 // E-vector -> single element
96 //------------------------------------------------------------------------------
97 template <int NUM_COMP, int P_1D>
98 inline __device__ void ReadElementStrided3d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp, const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
99   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
100     for (CeedInt z = 0; z < P_1D; z++) {
101       const CeedInt node = data.t_id_x + data.t_id_y*P_1D + z*P_1D*P_1D;
102       const CeedInt ind = node * strides_node + elem * strides_elem;
103       for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
104         r_u[z + comp * P_1D] = d_u[ind + comp * strides_comp];
105       }
106     }
107   }
108 }
109 
110 //------------------------------------------------------------------------------
111 // Single element -> E-vector
112 //------------------------------------------------------------------------------
113 template <int NUM_COMP, int P_1D>
114 inline __device__ void WriteElementStrided3d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp, const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
115   if (data.t_id_x < P_1D && data.t_id_y < P_1D) {
116     for (CeedInt z = 0; z < P_1D; z++) {
117       const CeedInt node = data.t_id_x + data.t_id_y*P_1D + z*P_1D*P_1D;
118       const CeedInt ind = node * strides_node + elem * strides_elem;
119       for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
120         d_v[ind + comp * strides_comp] = r_v[z + comp * P_1D];
121       }
122     }
123   }
124 }
125 
126 //------------------------------------------------------------------------------
127 
128 #endif
129