// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed

/// @file
/// Internal header for HIP shared memory basis read/write templates
#ifndef _ceed_hip_shared_basis_read_write_templates_h
#define _ceed_hip_shared_basis_read_write_templates_h

#include <ceed.h>

//------------------------------------------------------------------------------
// Helper function: load matrices for basis actions
//
// Copies entries [0, SIZE) of d_B into B. The work is split across every thread
// of the block: each thread starts at its flattened (x, y, z) thread index and
// advances by the total number of threads in the block.
// No barrier is issued here; if B lives in shared memory (presumably the case —
// confirm at call sites), callers must synchronize before other threads read B.
//------------------------------------------------------------------------------
template <int SIZE>
inline __device__ void loadMatrix(const CeedScalar *d_B, CeedScalar *B) {
  const unsigned int num_threads = blockDim.x * blockDim.y * blockDim.z;
  const CeedInt      first       = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);

  for (CeedInt k = first; k < SIZE; k += num_threads) {
    B[k] = d_B[k];
  }
}

//------------------------------------------------------------------------------
// 1D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// E-vector -> single element
//
// Thread t_id_x (when < P_1D) gathers the NUM_COMP values of node t_id_x of
// element elem into r_u[0 .. NUM_COMP-1]. The E-vector entry for
// (node, comp, elem) is at node * strides_node + comp * strides_comp + elem * strides_elem.
// Threads with t_id_x >= P_1D leave r_u untouched.
// NOTE(review): offsets are computed in CeedInt; if CeedInt is 32-bit, very large
// E-vectors could overflow these products — confirm the supported size range.
//------------------------------------------------------------------------------
template <int NUM_COMP, int P_1D>
inline __device__ void ReadElementStrided1d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                            const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
  // Only the first P_1D threads map onto nodes of the element
  if (data.t_id_x >= P_1D) return;

  const CeedInt     node   = data.t_id_x;
  const CeedScalar *u_node = d_u + (node * strides_node + elem * strides_elem);

  for (CeedInt c = 0; c < NUM_COMP; c++) r_u[c] = u_node[c * strides_comp];
}
//------------------------------------------------------------------------------
// Single element -> E-vector
//
// Thread t_id_x (when < P_1D) scatters its NUM_COMP register values
// r_v[0 .. NUM_COMP-1] to node t_id_x of element elem, at
// node * strides_node + comp * strides_comp + elem * strides_elem.
// Threads with t_id_x >= P_1D write nothing.
// NOTE(review): offsets are computed in CeedInt; if CeedInt is 32-bit, very large
// E-vectors could overflow these products — confirm the supported size range.
//------------------------------------------------------------------------------
template <int NUM_COMP, int P_1D>
inline __device__ void WriteElementStrided1d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                             const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
  // Only the first P_1D threads map onto nodes of the element
  if (data.t_id_x >= P_1D) return;

  const CeedInt node   = data.t_id_x;
  CeedScalar   *v_node = d_v + (node * strides_node + elem * strides_elem);

  for (CeedInt c = 0; c < NUM_COMP; c++) v_node[c * strides_comp] = r_v[c];
}

//------------------------------------------------------------------------------
// 2D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// E-vector -> single element
//
// Thread (t_id_x, t_id_y), when both are < P_1D, gathers the NUM_COMP values of
// node t_id_x + t_id_y * P_1D of element elem into r_u[0 .. NUM_COMP-1].
// Threads outside the P_1D x P_1D tile leave r_u untouched.
//------------------------------------------------------------------------------
template <int NUM_COMP, int P_1D>
inline __device__ void ReadElementStrided2d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                            const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
  // Skip threads that fall outside the P_1D x P_1D node tile
  if (data.t_id_x >= P_1D || data.t_id_y >= P_1D) return;

  const CeedInt     node   = data.t_id_x + data.t_id_y * P_1D;
  const CeedScalar *u_node = d_u + (node * strides_node + elem * strides_elem);

  for (CeedInt c = 0; c < NUM_COMP; c++) r_u[c] = u_node[c * strides_comp];
}

//------------------------------------------------------------------------------
// Single element -> E-vector
//
// Thread (t_id_x, t_id_y), when both are < P_1D, scatters r_v[0 .. NUM_COMP-1]
// to node t_id_x + t_id_y * P_1D of element elem.
// Threads outside the P_1D x P_1D tile write nothing.
//------------------------------------------------------------------------------
template <int NUM_COMP, int P_1D>
inline __device__ void WriteElementStrided2d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                             const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
  // Skip threads that fall outside the P_1D x P_1D node tile
  if (data.t_id_x >= P_1D || data.t_id_y >= P_1D) return;

  const CeedInt node   = data.t_id_x + data.t_id_y * P_1D;
  CeedScalar   *v_node = d_v + (node * strides_node + elem * strides_elem);

  for (CeedInt c = 0; c < NUM_COMP; c++) v_node[c * strides_comp] = r_v[c];
}

//------------------------------------------------------------------------------
// 3D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// E-vector -> single element
//
// Thread (t_id_x, t_id_y), when both are < P_1D, gathers the whole column of
// P_1D nodes along z at that (x, y) position of element elem. Register layout:
// r_u[z + comp * P_1D], so r_u must hold NUM_COMP * P_1D values.
// Threads outside the P_1D x P_1D tile leave r_u untouched.
//------------------------------------------------------------------------------
template <int NUM_COMP, int P_1D>
inline __device__ void ReadElementStrided3d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                            const CeedInt strides_elem, const CeedScalar *__restrict__ d_u, CeedScalar *r_u) {
  // Skip threads that fall outside the P_1D x P_1D node tile
  if (data.t_id_x >= P_1D || data.t_id_y >= P_1D) return;

  // Node index of the z = 0 layer for this thread's (x, y) position
  const CeedInt node_xy = data.t_id_x + data.t_id_y * P_1D;

  for (CeedInt z = 0; z < P_1D; z++) {
    const CeedInt     node   = node_xy + z * P_1D * P_1D;
    const CeedScalar *u_node = d_u + (node * strides_node + elem * strides_elem);

    for (CeedInt c = 0; c < NUM_COMP; c++) r_u[z + c * P_1D] = u_node[c * strides_comp];
  }
}

//------------------------------------------------------------------------------
// Single element -> E-vector
//
// Thread (t_id_x, t_id_y), when both are < P_1D, scatters its column of P_1D
// z-nodes back to element elem. Register layout: r_v[z + comp * P_1D].
// Threads outside the P_1D x P_1D tile write nothing.
//------------------------------------------------------------------------------
template <int NUM_COMP, int P_1D>
inline __device__ void WriteElementStrided3d(SharedData_Hip &data, const CeedInt elem, const CeedInt strides_node, const CeedInt strides_comp,
                                             const CeedInt strides_elem, const CeedScalar *r_v, CeedScalar *d_v) {
  // Skip threads that fall outside the P_1D x P_1D node tile
  if (data.t_id_x >= P_1D || data.t_id_y >= P_1D) return;

  // Node index of the z = 0 layer for this thread's (x, y) position
  const CeedInt node_xy = data.t_id_x + data.t_id_y * P_1D;

  for (CeedInt z = 0; z < P_1D; z++) {
    const CeedInt node   = node_xy + z * P_1D * P_1D;
    CeedScalar   *v_node = d_v + (node * strides_node + elem * strides_elem);

    for (CeedInt c = 0; c < NUM_COMP; c++) v_node[c * strides_comp] = r_v[z + c * P_1D];
  }
}

//------------------------------------------------------------------------------

#endif