// Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed

/// @file
/// Internal header for MAGMA backend common tensor basis definitions
#pragma once

#include "magma-common-defs.h"

////////////////////////////////////////////////////////////////////////////////
// read_1d: copy one 1D element (U or V) from global memory into the per-component
// shared-memory buffers sBuffer[0..NUM_COMP-1] (e.g. sU[][] or sV[][]).
//   devptr     - points directly at the element; component c begins at devptr + c * compstride
//   compstride - distance, in elements of T, between consecutive components
//   sBuffer    - NUM_COMP buffer pointers, each indexed over [0, LENGTH)
//   tx         - thread index; thread tx moves entry tx of every component,
//                threads with tx >= LENGTH do nothing
// No barrier is issued here: the caller must __syncthreads() before other threads read sBuffer.
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void read_1d(const T *devptr, const int compstride, T *sBuffer[NUM_COMP], const int tx) {
  if (tx >= LENGTH) return;  // idle thread: no entry to move
  for (int c = 0; c < NUM_COMP; ++c) {
    const T *src = devptr + c * compstride;
    sBuffer[c][tx] = src[tx];
  }
}

////////////////////////////////////////////////////////////////////////////////
// write_1d: copy one 1D element (V) from the per-component buffers sBuffer[][] (e.g. sV[][])
// back to global memory, overwriting the destination.
//   sBuffer    - NUM_COMP buffer pointers, each indexed over [0, LENGTH)
//   devptr     - points directly at the element; component c begins at devptr + c * compstride
//   compstride - distance, in elements of T, between consecutive components
//   tx         - thread index; thread tx stores entry tx of every component,
//                threads with tx >= LENGTH do nothing
// If sBuffer was filled by other threads, the caller must synchronize before this call.
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void write_1d(T *sBuffer[NUM_COMP], T *devptr, const int compstride, const int tx) {
  if (tx >= LENGTH) return;  // idle thread: no entry to store
  for (int c = 0; c < NUM_COMP; ++c) {
    T *dst = devptr + c * compstride;
    dst[tx] = sBuffer[c][tx];
  }
}

////////////////////////////////////////////////////////////////////////////////
// sum_1d: accumulate one 1D element (V) from the per-component buffers sBuffer[][] (e.g. sV[][])
// into global memory, i.e. devptr[...] += sBuffer[...].
// Same layout, arguments and thread mapping as write_1d. The update is a plain (non-atomic)
// read-modify-write.
// NOTE(review): this assumes no other thread or block updates the same devptr entries concurrently.
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void sum_1d(T *sBuffer[NUM_COMP], T *devptr, const int compstride, const int tx) {
  if (tx >= LENGTH) return;  // idle thread: no entry to accumulate
  for (int c = 0; c < NUM_COMP; ++c) {
    T *dst = devptr + c * compstride;
    dst[tx] += sBuffer[c][tx];
  }
}

////////////////////////////////////////////////////////////////////////////////
// read_U_2d: load U of a 2D element into registers rU[][][] for all components of a single dim.
//   dU         - already offset by elem-stride and dim-stride; component c holds P * P
//                contiguous values starting at dU + c * compstride
//   rU         - register array rU[DIM_U][NUM_COMP][rU_SIZE]; only slot i_DIM is filled
//                (first P entries), so rU_SIZE must be >= P (it may be larger, e.g. max(P, Q))
//   sTmp       - scratch workspace (shared memory) of at least P * P values, reused per component
//   tx         - thread index; threads tx < P participate
// Memory holds P contiguous vectors of length P. Threads first stage them with coalesced
// accesses (thread tx takes column tx of every vector). They then transpose through sTmp so
// that thread tx owns vector tx: rU[i_DIM][c][k] = dU[c * compstride + tx * P + k].
// Contains __syncthreads(), so every thread of the block must call it.
template <typename T, int P, int DIM_U, int NUM_COMP, int rU_SIZE, int i_DIM>
static __device__ __inline__ void read_U_2d(const T *dU, const int compstride, T rU[DIM_U][NUM_COMP][rU_SIZE], T *sTmp, const int tx) {
  const bool active = (tx < P);
  for (int c = 0; c < NUM_COMP; ++c) {
    const T *src = dU + c * compstride;

    // stage the component in sTmp, row by row
    if (active) {
      for (int row = 0; row < P; ++row) {
        sTmp[row * P + tx] = src[row * P + tx];
      }
    }
    __syncthreads();

    // transpose: thread tx picks up vector tx
    if (active) {
      for (int k = 0; k < P; ++k) {
        rU[i_DIM][c][k] = sTmp[tx * P + k];
      }
    }
    __syncthreads();  // sTmp is overwritten by the next component
  }
}

////////////////////////////////////////////////////////////////////////////////
// read_V_2d: load V of a 2D element from global memory into registers rV[][][] for all
// components of a single dim.
//   dV         - already offset by elem-stride and dim-stride; component comp holds Q * Q
//                values starting at dV + comp * compstride
//   rV         - register array rV[DIM_V][NUM_COMP][rV_SIZE]; i_DIM selects the slot being filled
//                (dV itself is never offset by i_DIM)
//   rV_SIZE    - only the first Q entries of rV[i_DIM][comp][] are written, so rV_SIZE must be
//                >= Q (it may be larger, e.g. max(P, Q))
//   tx         - thread index; threads tx < Q participate
// Thread tx reads column tx: rV[i_DIM][comp][j] = dV[comp * compstride + j * Q + tx].
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void read_V_2d(const T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx < Q) {
    for (int comp = 0; comp < NUM_COMP; comp++) {
      for (int j = 0; j < Q; j++) {
        rV[i_DIM][comp][j] = dV[comp * compstride + j * Q + tx];
      }
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// write_V_2d: store V of a 2D element from registers rV[][][] to global memory, overwriting
// the destination, for all components of a single dim.
//   dV         - already offset by elem-stride and dim-stride; component comp holds Q * Q
//                values starting at dV + comp * compstride
//   rV         - register array rV[DIM_V][NUM_COMP][rV_SIZE]; i_DIM selects which register slot
//                is written out (dV itself is never offset by i_DIM)
//   rV_SIZE    - only the first Q entries are read, so rV_SIZE must be >= Q (e.g. max(P, Q))
//   tx         - thread index; threads tx < Q participate
// Thread tx writes column tx: dV[comp * compstride + j * Q + tx] = rV[i_DIM][comp][j].
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void write_V_2d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx < Q) {
    for (int comp = 0; comp < NUM_COMP; comp++) {
      for (int j = 0; j < Q; j++) {
        dV[comp * compstride + j * Q + tx] = rV[i_DIM][comp][j];
      }
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// sum_V_2d: accumulate V of a 2D element from registers rV[][][] into global memory
// (dV[...] += rV[...]) for all components of a single dim.
// Same layout, arguments and thread mapping as write_V_2d. The update is a plain (non-atomic)
// read-modify-write.
// NOTE(review): this assumes no concurrent writer touches the same dV entries.
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void sum_V_2d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx < Q) {
    for (int comp = 0; comp < NUM_COMP; comp++) {
      for (int j = 0; j < Q; j++) {
        dV[comp * compstride + j * Q + tx] += rV[i_DIM][comp][j];
      }
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// read_U_3d: load U of a 3D element into registers rU[][][] for all components of a single dim.
//   dU         - already offset by elem-stride and dim-stride; component comp holds P^3
//                contiguous values starting at dU + comp * compstride
//   rU         - register array rU[DIM_U][NUM_COMP][rU_SIZE]; only slot i_DIM is filled
//                (first P entries), so rU_SIZE must be >= P (it may be larger, e.g. max(P, Q))
//   sTmp       - scratch workspace (shared memory) of at least P^3 values, reused per component
//   tx         - thread index; threads tx < P * P participate
// Contains __syncthreads(), so every thread of the block must call it.
template <typename T, int P, int DIM_U, int NUM_COMP, int rU_SIZE, int i_DIM>
static __device__ __inline__ void read_U_3d(const T *dU, const int compstride, T rU[DIM_U][NUM_COMP][rU_SIZE], T *sTmp, const int tx) {
  // U is read as a batch of P^2 vectors of size (1 x P):
  // vec 0    : [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // vec 1    : [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // ...
  // vec P^2-1: [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // Threads read collaboratively (coalesced) into sTmp. The kernel, however, wants thread 0 to
  // hold all of vec 0 in registers, thread 1 all of vec 1, and so on. So we transpose through
  // shared memory: rU[i_DIM][comp][i] = dU[comp * compstride + tx * P + i].
  for (int comp = 0; comp < NUM_COMP; comp++) {
    // read from global memory into shared memory
    if (tx < P * P) {
      for (int i = 0; i < P; i++) {
        sTmp[i * P * P + tx] = dU[comp * compstride + i * P * P + tx];
      }
    }
    __syncthreads();

    // transpose: thread tx picks up vector tx
    if (tx < P * P) {
      for (int i = 0; i < P; i++) {
        rU[i_DIM][comp][i] = sTmp[tx * P + i];
      }
    }
    __syncthreads();  // sTmp is overwritten by the next component
  }
}

////////////////////////////////////////////////////////////////////////////////
// read_V_3d: load V of a 3D element from global memory into registers rV[][][] for all
// components of a single dim.
//   dV         - already offset by elem-stride and dim-stride; component comp holds Q^3
//                values starting at dV + comp * compstride
//   rV         - register array rV[DIM_V][NUM_COMP][rV_SIZE]; i_DIM selects the slot being filled
//                (dV itself is never offset by i_DIM)
//   rV_SIZE    - only the first Q entries of rV[i_DIM][comp][] are written, so rV_SIZE must be
//                >= Q (it may be larger, e.g. max(P, Q))
//   tx         - thread index; threads tx < Q * Q participate
// Thread tx reads rV[i_DIM][comp][j] = dV[comp * compstride + j * (Q * Q) + tx].
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void read_V_3d(const T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx < Q * Q) {
    for (int comp = 0; comp < NUM_COMP; comp++) {
      for (int j = 0; j < Q; j++) {
        rV[i_DIM][comp][j] = dV[comp * compstride + j * (Q * Q) + tx];
      }
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// write_V_3d: store V of a 3D element from registers rV[][][] to global memory, overwriting
// the destination, for all components of a single dim.
//   dV         - must already include any element and dimension offset; component comp holds
//                Q^3 values starting at dV + comp * compstride
//   rV         - register array rV[DIM_V][NUM_COMP][rV_SIZE]; i_DIM selects which register slot
//                is written out (dV itself is never offset by i_DIM)
//   rV_SIZE    - only the first Q entries are read, so rV_SIZE must be >= Q (e.g. max(P, Q))
//   tx         - thread index; threads tx < Q * Q participate
// Thread tx writes dV[comp * compstride + j * (Q * Q) + tx] = rV[i_DIM][comp][j].
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void write_V_3d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx < (Q * Q)) {
    for (int comp = 0; comp < NUM_COMP; comp++) {
      for (int j = 0; j < Q; j++) {
        dV[comp * compstride + j * (Q * Q) + tx] = rV[i_DIM][comp][j];
      }
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// sum_V_3d: accumulate V of a 3D element from registers rV[][][] into global memory
// (dV[...] += rV[...]) for all components of a single dim.
// Same layout, arguments and thread mapping as write_V_3d. The update is a plain (non-atomic)
// read-modify-write.
// NOTE(review): this assumes no concurrent writer touches the same dV entries.
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void sum_V_3d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx < (Q * Q)) {
    for (int comp = 0; comp < NUM_COMP; comp++) {
      for (int j = 0; j < Q; j++) {
        dV[comp * compstride + j * (Q * Q) + tx] += rV[i_DIM][comp][j];
      }
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// read_T_notrans_gm2sm: copy the B x J matrix T from global memory (dT) into shared memory (sT)
// without transposing: sT[i * B + tx] = dT[i * B + tx] for i < J, i.e. B * J values copied as-is.
// Threads tx < B participate; thread tx copies entry tx of each of the J length-B runs.
// No barrier is issued here: the caller must __syncthreads() before reading sT.
template <int B, int J>
static __device__ __inline__ void read_T_notrans_gm2sm(const int tx, const CeedScalar *dT, CeedScalar *sT) {
  if (tx < B) {
    for (int i = 0; i < J; i++) {
      sT[i * B + tx] = dT[i * B + tx];
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// read_T_trans_gm2sm: copy T from global memory (dT) into shared memory (sT), transposing it
// on the way (T is J x B).
// The entry stored at dT[i * J + tx] (i < B, tx < J) lands at sT[tx * B + i].
// Threads tx < J participate; each thread reads with unit stride across threads (coalesced)
// and writes B consecutive entries of sT.
// No barrier is issued here: the caller must __syncthreads() before reading sT.
template <int B, int J>
static __device__ __inline__ void read_T_trans_gm2sm(const int tx, const CeedScalar *dT, CeedScalar *sT) {
  if (tx < J) {
    for (int i = 0; i < B; i++) {
      sT[tx * B + i] = dT[i * J + tx];
    }
  }
}