// Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed

/// @file
/// Internal header for MAGMA backend common tensor basis definitions
#ifndef CEED_MAGMA_COMMON_TENSOR_H
#define CEED_MAGMA_COMMON_TENSOR_H

#include "magma-common-defs.h"

////////////////////////////////////////////////////////////////////////////////
// read U or V of a 1D element into shared memory sU[][] or sV[][] -- for all components
// the devptr is assumed to point directly to the element
// component comp starts at devptr[comp * compstride]; sBuffer[comp] must hold at least LENGTH entries
// thread tx (0 <= tx < LENGTH) copies entry tx of every component; threads with tx >= LENGTH do nothing,
// so LENGTH distinct values of tx are needed to copy the whole element
// no barrier is issued here: must sync after call before other threads read sBuffer
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void read_1d(const T *devptr, const int compstride, T *sBuffer[NUM_COMP], const int tx) {
  if (tx < LENGTH) {
    for (int comp = 0; comp < NUM_COMP; comp++) {
      sBuffer[comp][tx] = devptr[comp * compstride + tx];
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// write V of a 1D element into global memory from sV[][] -- for all components
// the devptr is assumed to point directly to the element
// component comp is written to devptr[comp * compstride + 0 .. LENGTH-1]
// thread tx (0 <= tx < LENGTH) writes entry tx of every component; threads with tx >= LENGTH do nothing
// sBuffer is read without a barrier: if sBuffer[comp][tx] was produced by another thread,
// the caller must sync before this call
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void write_1d(T *sBuffer[NUM_COMP], T *devptr, const int compstride, const int tx) {
  if (tx < LENGTH) {
    for (int comp = 0; comp < NUM_COMP; comp++) {
      devptr[comp * compstride + tx] = sBuffer[comp][tx];
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// read U of a 2D element into registers rU[][][] -- for all components of a single dim
// dU is assumed to be offset by elem-stride and dim-stride
// register is assumed to be rU[DIM_U][NUM_COMP][rU_SIZE]
// i_DIM specifies which dimension is being read into in rU
// rU_SIZE can be different from P (e.g. max(P, Q)) but must be at least P;
// only rU[i_DIM][comp][0 .. P-1] is written
// sTmp is a shared memory workspace of size P^2
// on exit, thread tx (0 <= tx < P) holds rU[i_DIM][comp][i] = dU[comp * compstride + tx * P + i]
// contains __syncthreads(): must be called by every thread of the block (never from divergent code)
template <typename T, int P, int DIM_U, int NUM_COMP, int rU_SIZE, int i_DIM>
static __device__ __inline__ void read_U_2d(const T *dU, const int compstride, T rU[DIM_U][NUM_COMP][rU_SIZE], T *sTmp, const int tx) {
  static_assert(0 <= i_DIM && i_DIM < DIM_U, "read_U_2d: i_DIM must index a slot of rU[DIM_U]");
  static_assert(rU_SIZE >= P, "read_U_2d: rU_SIZE must be at least P");
  // read U as a batch P of (1 x P) vectors
  // vec 0  : [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // vec 1  : [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // ...
  // vec P-1: [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // threads collaboratively read vec0 and then vec1 and so on
  // but for the kernel, we want
  // thread 0 to hold all of vec0 in registers, and
  // thread 1 to hold all of vec1 in registers, and so on
  // so we need to transpose
  for (int comp = 0; comp < NUM_COMP; comp++) {
    // read from global memory into shared memory;
    // for a fixed i, consecutive threads touch consecutive addresses of dU
    if (tx < P) {
      for (int i = 0; i < P; i++) {
        sTmp[i * P + tx] = dU[comp * compstride + i * P + tx];
      }
    }
    __syncthreads();

    // thread tx picks up vec tx (sTmp[tx * P .. tx * P + P - 1])
    if (tx < P) {
      for (int i = 0; i < P; i++) {
        rU[i_DIM][comp][i] = sTmp[tx * P + i];
      }
    }
    // every thread must finish reading sTmp before it is overwritten (next component or caller)
    __syncthreads();
  }
}

////////////////////////////////////////////////////////////////////////////////
// read V of a 2D element into registers rV[][][] -- for all components of a single dim
// dV is assumed to be offset by elem-stride and dim-stride
// register is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM specifies which dimension is being read into in rV
// rV_SIZE can be different from Q (e.g. max(P, Q)) but must be at least Q;
// only rV[i_DIM][comp][0 .. Q-1] is written
// thread tx (0 <= tx < Q) loads rV[i_DIM][comp][j] = dV[comp * compstride + j * Q + tx], j = 0 .. Q-1
// no shared memory and no barrier are used
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void read_V_2d(const T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  static_assert(0 <= i_DIM && i_DIM < DIM_V, "read_V_2d: i_DIM must index a slot of rV[DIM_V]");
  static_assert(rV_SIZE >= Q, "read_V_2d: rV_SIZE must be at least Q");
  if (tx < Q) {
    for (int comp = 0; comp < NUM_COMP; comp++) {
      for (int j = 0; j < Q; j++) {
        rV[i_DIM][comp][j] = dV[comp * compstride + j * Q + tx];
      }
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// write V of a 2D element from registers rV[][][] to global memory -- for all components of a single dim
// dV is assumed to be offset by elem-stride and dim-stride
// register is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM specifies which dimension is being written to in dV
// rV_SIZE can be different from Q (e.g. max(P, Q)) but must be at least Q;
// only rV[i_DIM][comp][0 .. Q-1] is read
// thread tx (0 <= tx < Q) stores dV[comp * compstride + j * Q + tx] = rV[i_DIM][comp][j], j = 0 .. Q-1
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void write_V_2d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  static_assert(0 <= i_DIM && i_DIM < DIM_V, "write_V_2d: i_DIM must index a slot of rV[DIM_V]");
  static_assert(rV_SIZE >= Q, "write_V_2d: rV_SIZE must be at least Q");
  if (tx < Q) {
    for (int comp = 0; comp < NUM_COMP; comp++) {
      for (int j = 0; j < Q; j++) {
        dV[comp * compstride + j * Q + tx] = rV[i_DIM][comp][j];
      }
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// read U of a 3D element into registers rU[][][] -- for all components of a single dim
// dU is assumed to be offset by elem-stride and dim-stride
// register is assumed to be rU[DIM_U][NUM_COMP][rU_SIZE]
// i_DIM specifies which dimension is being read into in rU
// rU_SIZE can be different from P (e.g. max(P, Q)) but must be at least P;
// only rU[i_DIM][comp][0 .. P-1] is written
// sTmp is a shared memory workspace of size P^3
// on exit, thread tx (0 <= tx < P^2) holds rU[i_DIM][comp][i] = dU[comp * compstride + tx * P + i]
// contains __syncthreads(): must be called by every thread of the block (never from divergent code)
template <typename T, int P, int DIM_U, int NUM_COMP, int rU_SIZE, int i_DIM>
static __device__ __inline__ void read_U_3d(const T *dU, const int compstride, T rU[DIM_U][NUM_COMP][rU_SIZE], T *sTmp, const int tx) {
  static_assert(0 <= i_DIM && i_DIM < DIM_U, "read_U_3d: i_DIM must index a slot of rU[DIM_U]");
  static_assert(rU_SIZE >= P, "read_U_3d: rU_SIZE must be at least P");
  // read U as a batch P^2 of (1 x P) vectors
  // vec 0    : [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // vec 1    : [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // ...
  // vec P^2-1: [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // threads collaboratively read vec0 and then vec1 and so on
  // but for the kernel, we want
  // thread 0 to hold all of vec0 in registers, and
  // thread 1 to hold all of vec1 in registers, and so on
  // so we need to transpose
  for (int comp = 0; comp < NUM_COMP; comp++) {
    // read from global memory into shared memory;
    // for a fixed i, consecutive threads touch consecutive addresses of dU
    if (tx < P * P) {
      for (int i = 0; i < P; i++) {
        sTmp[i * P * P + tx] = dU[comp * compstride + i * P * P + tx];
      }
    }
    __syncthreads();

    // thread tx picks up vec tx (sTmp[tx * P .. tx * P + P - 1])
    if (tx < P * P) {
      for (int i = 0; i < P; i++) {
        rU[i_DIM][comp][i] = sTmp[tx * P + i];
      }
    }
    // every thread must finish reading sTmp before it is overwritten (next component or caller)
    __syncthreads();
  }
}

////////////////////////////////////////////////////////////////////////////////
// read V of a 3D element into registers rV[][][] -- for all components of a single dim
// dV is assumed to be offset by elem-stride and dim-stride
// register is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM specifies which dimension is being read into in rV
// rV_SIZE can be different from Q (e.g. max(P, Q)) but must be at least Q;
// only rV[i_DIM][comp][0 .. Q-1] is written
// thread tx (0 <= tx < Q^2) loads rV[i_DIM][comp][j] = dV[comp * compstride + j * Q^2 + tx], j = 0 .. Q-1
// no shared memory and no barrier are used
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void read_V_3d(const T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  static_assert(0 <= i_DIM && i_DIM < DIM_V, "read_V_3d: i_DIM must index a slot of rV[DIM_V]");
  static_assert(rV_SIZE >= Q, "read_V_3d: rV_SIZE must be at least Q");
  if (tx < Q * Q) {
    for (int comp = 0; comp < NUM_COMP; comp++) {
      for (int j = 0; j < Q; j++) {
        rV[i_DIM][comp][j] = dV[comp * compstride + j * (Q * Q) + tx];
      }
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// write V of a 3D element from registers rV[][][] to global memory -- for all components of a single dim
// dV is assumed to be offset by elem-stride and dim-stride
// (no dimension offset is applied here; i_DIM only selects the register slot)
// register is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM specifies which dimension is being written to in dV
// rV_SIZE can be different from Q (e.g. max(P, Q)) but must be at least Q;
// only rV[i_DIM][comp][0 .. Q-1] is read
// thread tx (0 <= tx < Q^2) stores dV[comp * compstride + j * Q^2 + tx] = rV[i_DIM][comp][j], j = 0 .. Q-1
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void write_V_3d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  static_assert(0 <= i_DIM && i_DIM < DIM_V, "write_V_3d: i_DIM must index a slot of rV[DIM_V]");
  static_assert(rV_SIZE >= Q, "write_V_3d: rV_SIZE must be at least Q");
  if (tx < (Q * Q)) {
    for (int comp = 0; comp < NUM_COMP; comp++) {
      for (int j = 0; j < Q; j++) {
        dV[comp * compstride + j * (Q * Q) + tx] = rV[i_DIM][comp][j];
      }
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// reads T (no-trans) into shared memory
// T is B x J with leading dimension B: entry (b, j) lives at index j * B + b in both dT and sT
// the data is copied as-is; thread tx (0 <= tx < B) copies row tx of every column
// must sync after call
template <int B, int J>
static __device__ __inline__ void read_T_notrans_gm2sm(const int tx, const CeedScalar *dT, CeedScalar *sT) {
  if (tx < B) {
    for (int i = 0; i < J; i++) {
      sT[i * B + tx] = dT[i * B + tx];
    }
  }
  // must sync after call
}

////////////////////////////////////////////////////////////////////////////////
// reads T (trans) into shared memory
// dT is J x B with leading dimension J: entry (j, b) lives at index b * J + j
// on exit sT holds its transpose, B x J with leading dimension B (entry (b, j) at j * B + b),
// i.e. the same sT layout produced by read_T_notrans_gm2sm
// thread tx (0 <= tx < J) handles row tx of dT
// must sync after call
template <int B, int J>
static __device__ __inline__ void read_T_trans_gm2sm(const int tx, const CeedScalar *dT, CeedScalar *sT) {
  if (tx < J) {
    for (int i = 0; i < B; i++) {
      sT[tx * B + i] = dT[i * J + tx];
    }
  }
  // must sync after call
}

#endif  // CEED_MAGMA_COMMON_TENSOR_H