// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed

/// @file
/// Internal header for HIP shared memory basis read/write templates
#ifndef _ceed_hip_shared_basis_read_write_templates_h
#define _ceed_hip_shared_basis_read_write_templates_h

#include <ceed.h>

//------------------------------------------------------------------------------
// Helper function: load matrices for basis actions
//------------------------------------------------------------------------------
// Copies the first SIZE entries of d_B into B, splitting the work across every
// thread of the block. The (x, y, z) thread index is flattened into a single
// rank, and each thread copies entries rank, rank + nthreads, rank + 2*nthreads, ...
// No barrier is issued here.
// NOTE(review): B is presumably a shared-memory buffer that other threads read
// afterwards. If so, callers must synchronize the block first — confirm at call sites.
template <int SIZE>
inline __device__ void loadMatrix(const CeedScalar* d_B, CeedScalar* B) {
  const CeedInt rank     = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
  const CeedInt nthreads = blockDim.x * blockDim.y * blockDim.z;
  for (CeedInt k = rank; k < SIZE; k += nthreads) {
    B[k] = d_B[k];
  }
}

//------------------------------------------------------------------------------
// 1D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// E-vector -> single element
//------------------------------------------------------------------------------
// Gathers the NUM_COMP values of node data.t_id_x of element `elem` from the
// strided E-vector d_u. Entry (node, comp, elem) lives at index
// node * strides_node + comp * strides_comp + elem * strides_elem.
// r_u receives NUM_COMP values (r_u[comp]). Threads with data.t_id_x >= P_1D
// read nothing and leave r_u untouched.
template <int NUM_COMP, int P_1D>
inline __device__ void ReadElementStrided1d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                            const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
  // This thread has no node to read.
  if (data.t_id_x >= P_1D) return;

  const CeedInt node   = data.t_id_x;
  const CeedInt offset = elem * strides_elem + node * strides_node;
  for (CeedInt c = 0; c < NUM_COMP; c++) r_u[c] = d_u[offset + c * strides_comp];
}
//------------------------------------------------------------------------------
// Single element -> E-vector
//------------------------------------------------------------------------------
// Scatters the NUM_COMP values held in r_v (r_v[comp]) into node data.t_id_x of
// element `elem` in the strided E-vector d_v. The entry goes to index
// node * strides_node + comp * strides_comp + elem * strides_elem.
// Threads with data.t_id_x >= P_1D write nothing.
template <int NUM_COMP, int P_1D>
inline __device__ void WriteElementStrided1d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                             const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
  // This thread has no node to write.
  if (data.t_id_x >= P_1D) return;

  const CeedInt node   = data.t_id_x;
  const CeedInt offset = elem * strides_elem + node * strides_node;
  for (CeedInt c = 0; c < NUM_COMP; c++) d_v[offset + c * strides_comp] = r_v[c];
}

//------------------------------------------------------------------------------
// 2D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// E-vector -> single element
//------------------------------------------------------------------------------
// Gathers the NUM_COMP values of one node of element `elem` from the strided
// E-vector d_u into r_u (r_u[comp]). The node index is data.t_id_x + data.t_id_y * P_1D,
// i.e. t_id_x varies fastest. Threads with either coordinate >= P_1D read nothing.
template <int NUM_COMP, int P_1D>
inline __device__ void ReadElementStrided2d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                            const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
  if (data.t_id_x >= P_1D || data.t_id_y >= P_1D) return;

  const CeedInt node   = data.t_id_x + data.t_id_y * P_1D;
  const CeedInt offset = elem * strides_elem + node * strides_node;
  for (CeedInt c = 0; c < NUM_COMP; c++) r_u[c] = d_u[offset + c * strides_comp];
}

//------------------------------------------------------------------------------
// Single element -> E-vector
//------------------------------------------------------------------------------
// Scatters r_v (r_v[comp]) into node data.t_id_x + data.t_id_y * P_1D of element
// `elem` in the strided E-vector d_v. Threads with either coordinate >= P_1D
// write nothing.
template <int NUM_COMP, int P_1D>
inline __device__ void WriteElementStrided2d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                             const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
  if (data.t_id_x >= P_1D || data.t_id_y >= P_1D) return;

  const CeedInt node   = data.t_id_x + data.t_id_y * P_1D;
  const CeedInt offset = elem * strides_elem + node * strides_node;
  for (CeedInt c = 0; c < NUM_COMP; c++) d_v[offset + c * strides_comp] = r_v[c];
}

//------------------------------------------------------------------------------
// 3D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// E-vector -> single element
//------------------------------------------------------------------------------
// Gathers a full column of P_1D nodes (one per z) of element `elem` from the
// strided E-vector d_u. The node index is t_id_x + t_id_y * P_1D + z * P_1D * P_1D.
// r_u must hold NUM_COMP * P_1D values, laid out as r_u[z + comp * P_1D].
// Threads with either coordinate >= P_1D read nothing.
template <int NUM_COMP, int P_1D>
inline __device__ void ReadElementStrided3d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                            const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
  if (data.t_id_x >= P_1D || data.t_id_y >= P_1D) return;

  // The in-plane part of the node index is fixed for this thread; only z varies.
  const CeedInt node_xy = data.t_id_x + data.t_id_y * P_1D;
  for (CeedInt z = 0; z < P_1D; z++) {
    const CeedInt offset = elem * strides_elem + (node_xy + z * P_1D * P_1D) * strides_node;
    for (CeedInt c = 0; c < NUM_COMP; c++) r_u[z + c * P_1D] = d_u[offset + c * strides_comp];
  }
}

//------------------------------------------------------------------------------
// Single element -> E-vector
//------------------------------------------------------------------------------
// Scatters a column of P_1D nodes (one per z) from r_v into element `elem` of the
// strided E-vector d_v. r_v is laid out as r_v[z + comp * P_1D], and the node index
// is t_id_x + t_id_y * P_1D + z * P_1D * P_1D. Threads with either coordinate >= P_1D
// write nothing.
template <int NUM_COMP, int P_1D>
inline __device__ void WriteElementStrided3d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                             const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
  if (data.t_id_x >= P_1D || data.t_id_y >= P_1D) return;

  // The in-plane part of the node index is fixed for this thread; only z varies.
  const CeedInt node_xy = data.t_id_x + data.t_id_y * P_1D;
  for (CeedInt z = 0; z < P_1D; z++) {
    const CeedInt offset = elem * strides_elem + (node_xy + z * P_1D * P_1D) * strides_node;
    for (CeedInt c = 0; c < NUM_COMP; c++) d_v[offset + c * strides_comp] = r_v[z + c * P_1D];
  }
}

//------------------------------------------------------------------------------

#endif