// Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed

/// @file
/// Internal header for MAGMA backend common tensor basis definitions
10509d4af6SJeremy L Thompson #pragma once
11f80f4a74SSebastian Grimberg
123c1e2affSSebastian Grimberg #include "magma-common-defs.h"
13f80f4a74SSebastian Grimberg
////////////////////////////////////////////////////////////////////////////////
// Copy U or V of a 1D element from global memory into the per-component shared
// buffers sU[][] / sV[][] -- all components at once.
//   devptr     : points directly at the element's first entry
//   compstride : distance (in entries of T) between consecutive components
//   sBuffer    : NUM_COMP destination buffers, each indexed over [0, LENGTH)
//   tx         : thread index; thread tx moves entry tx of every component,
//                threads with tx >= LENGTH do nothing
// No barrier is issued here -- the caller must sync after the call.
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void read_1d(const T *devptr, const int compstride, T *sBuffer[NUM_COMP], const int tx) {
  if (tx >= LENGTH) return;

  for (int c = 0; c < NUM_COMP; c++) {
    const T *src = devptr + c * compstride;  // start of component c
    sBuffer[c][tx] = src[tx];
  }
}
26f80f4a74SSebastian Grimberg
////////////////////////////////////////////////////////////////////////////////
// Store V of a 1D element from the per-component shared buffers sV[][] into
// global memory -- all components at once (plain overwrite).
//   sBuffer    : NUM_COMP source buffers, each indexed over [0, LENGTH)
//   devptr     : points directly at the element's first entry
//   compstride : distance (in entries of T) between consecutive components
//   tx         : thread index; thread tx stores entry tx of every component,
//                threads with tx >= LENGTH do nothing
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void write_1d(T *sBuffer[NUM_COMP], T *devptr, const int compstride, const int tx) {
  if (tx >= LENGTH) return;

  for (int c = 0; c < NUM_COMP; c++) {
    T *dst = devptr + c * compstride;  // start of component c
    dst[tx] = sBuffer[c][tx];
  }
}
38f80f4a74SSebastian Grimberg
////////////////////////////////////////////////////////////////////////////////
// Accumulate (+=) V of a 1D element from the per-component shared buffers sV[][]
// into global memory -- all components at once.
//   sBuffer    : NUM_COMP source buffers, each indexed over [0, LENGTH)
//   devptr     : points directly at the element's first entry
//   compstride : distance (in entries of T) between consecutive components
//   tx         : thread index; thread tx updates entry tx of every component,
//                threads with tx >= LENGTH do nothing
// The update is a plain (non-atomic) +=, so presumably no other block touches
// the same element concurrently -- NOTE(review): confirm with callers.
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void sum_1d(T *sBuffer[NUM_COMP], T *devptr, const int compstride, const int tx) {
  if (tx >= LENGTH) return;

  for (int c = 0; c < NUM_COMP; c++) {
    T *dst = devptr + c * compstride;  // start of component c
    dst[tx] += sBuffer[c][tx];
  }
}
50db2becc9SJeremy L Thompson
////////////////////////////////////////////////////////////////////////////////
// Load U of a 2D element into registers rU[i_DIM][comp][0..P) for every
// component, transposing through shared memory.
//   dU         : already offset by elem-stride and dim-stride
//   compstride : distance (in entries of T) between consecutive components
//   rU         : register array rU[DIM_U][NUM_COMP][rU_SIZE]; only slot i_DIM is
//                written. rU_SIZE may differ from P (e.g. max(P, Q)); only the
//                first P entries of each component are set.
//   sTmp       : shared-memory workspace of at least P^2 entries
//   tx         : thread index; threads with tx >= P only take part in the barriers
//
// Per component, U is a batch of P contiguous vectors of length P:
//   vec 0   : [u0, u1, ..., u_(P-1)]
//   ...
//   vec P-1 : [u0, u1, ..., u_(P-1)]
// The threads read each vector cooperatively (thread tx takes entry tx, so the
// global loads are contiguous across threads), but the kernel needs thread tx to
// own all of vec tx. Staging the vectors in sTmp and reading them back transposed
// gives each thread its own vector.
//
// Contains __syncthreads(): every thread of the block must call this function.
template <typename T, int P, int DIM_U, int NUM_COMP, int rU_SIZE, int i_DIM>
static __device__ __inline__ void read_U_2d(const T *dU, const int compstride, T rU[DIM_U][NUM_COMP][rU_SIZE], T *sTmp, const int tx) {
  const bool active = (tx < P);

  for (int c = 0; c < NUM_COMP; c++) {
    const T *src = dU + c * compstride;  // start of component c

    // Stage 1: global -> shared, same layout (vector r occupies sTmp[r * P .. r * P + P)).
    if (active) {
      for (int r = 0; r < P; r++) {
        sTmp[r * P + tx] = src[r * P + tx];
      }
    }
    __syncthreads();

    // Stage 2: shared -> registers, transposed (thread tx collects vector tx).
    if (active) {
      for (int k = 0; k < P; k++) {
        rU[i_DIM][c][k] = sTmp[tx * P + k];
      }
    }
    __syncthreads();  // sTmp is reused by the next component
  }
}
87f80f4a74SSebastian Grimberg
////////////////////////////////////////////////////////////////////////////////
// Load V of a 2D element into registers rV[i_DIM][comp][0..Q) for every component.
//   dV         : already offset by elem-stride and dim-stride
//   compstride : distance (in entries of T) between consecutive components
//   rV         : register array rV[DIM_V][NUM_COMP][rV_SIZE]; only slot i_DIM is
//                written. rV_SIZE may differ from Q (e.g. max(P, Q)); only the
//                first Q entries of each component are set.
//   tx         : thread index; thread tx gathers entries tx, Q + tx, ...,
//                (Q-1)*Q + tx of each component; threads with tx >= Q do nothing
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void read_V_2d(const T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx >= Q) return;

  for (int c = 0; c < NUM_COMP; c++) {
    const T *src = dV + c * compstride + tx;  // entry tx of component c
    for (int k = 0; k < Q; k++) {
      rV[i_DIM][c][k] = src[k * Q];
    }
  }
}
104f80f4a74SSebastian Grimberg
////////////////////////////////////////////////////////////////////////////////
// Store V of a 2D element from registers rV[i_DIM][comp][0..Q) into global
// memory for every component (plain overwrite).
//   dV         : already offset by elem-stride and dim-stride
//   compstride : distance (in entries of T) between consecutive components
//   rV         : register array rV[DIM_V][NUM_COMP][rV_SIZE]; only slot i_DIM is
//                read. rV_SIZE may differ from Q (e.g. max(P, Q)).
//   tx         : thread index; thread tx writes entries tx, Q + tx, ...,
//                (Q-1)*Q + tx of each component; threads with tx >= Q do nothing
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void write_V_2d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx >= Q) return;

  for (int c = 0; c < NUM_COMP; c++) {
    T *dst = dV + c * compstride + tx;  // entry tx of component c
    for (int k = 0; k < Q; k++) {
      dst[k * Q] = rV[i_DIM][c][k];
    }
  }
}
121f80f4a74SSebastian Grimberg
////////////////////////////////////////////////////////////////////////////////
// Accumulate (+=) V of a 2D element from registers rV[i_DIM][comp][0..Q) into
// global memory for every component.
//   dV         : already offset by elem-stride and dim-stride
//   compstride : distance (in entries of T) between consecutive components
//   rV         : register array rV[DIM_V][NUM_COMP][rV_SIZE]; only slot i_DIM is
//                read. rV_SIZE may differ from Q (e.g. max(P, Q)).
//   tx         : thread index; thread tx updates entries tx, Q + tx, ...,
//                (Q-1)*Q + tx of each component; threads with tx >= Q do nothing
// The update is a plain (non-atomic) +=.
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void sum_V_2d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx >= Q) return;

  for (int c = 0; c < NUM_COMP; c++) {
    T *dst = dV + c * compstride + tx;  // entry tx of component c
    for (int k = 0; k < Q; k++) {
      dst[k * Q] += rV[i_DIM][c][k];
    }
  }
}
138db2becc9SJeremy L Thompson
////////////////////////////////////////////////////////////////////////////////
// Load U of a 3D element into registers rU[i_DIM][comp][0..P) for every
// component, transposing through shared memory.
//   dU         : already offset by elem-stride and dim-stride
//   compstride : distance (in entries of T) between consecutive components
//   rU         : register array rU[DIM_U][NUM_COMP][rU_SIZE]; only slot i_DIM is
//                written. rU_SIZE may differ from P (e.g. max(P, Q)); only the
//                first P entries of each component are set.
//   sTmp       : shared-memory workspace of at least P^3 entries
//   tx         : thread index; threads with tx >= P^2 only take part in the barriers
//
// Per component, U is a batch of P^2 contiguous vectors of length P:
//   vec 0     : [u0, u1, ..., u_(P-1)]
//   ...
//   vec P^2-1 : [u0, u1, ..., u_(P-1)]
// The threads read the data cooperatively in chunks of P^2 consecutive entries
// (contiguous across threads), but the kernel needs thread tx to own all of
// vec tx. Staging in sTmp and reading back transposed gives each thread its own
// vector.
//
// Contains __syncthreads(): every thread of the block must call this function.
template <typename T, int P, int DIM_U, int NUM_COMP, int rU_SIZE, int i_DIM>
static __device__ __inline__ void read_U_3d(const T *dU, const int compstride, T rU[DIM_U][NUM_COMP][rU_SIZE], T *sTmp, const int tx) {
  const int  num_vecs = P * P;  // vectors per component == participating threads
  const bool active   = (tx < num_vecs);

  for (int c = 0; c < NUM_COMP; c++) {
    const T *src = dU + c * compstride;  // start of component c

    // Stage 1: global -> shared, same layout.
    if (active) {
      for (int r = 0; r < P; r++) {
        sTmp[r * num_vecs + tx] = src[r * num_vecs + tx];
      }
    }
    __syncthreads();

    // Stage 2: shared -> registers, transposed (thread tx collects vector tx).
    if (active) {
      for (int k = 0; k < P; k++) {
        rU[i_DIM][c][k] = sTmp[tx * P + k];
      }
    }
    __syncthreads();  // sTmp is reused by the next component
  }
}
175f80f4a74SSebastian Grimberg
////////////////////////////////////////////////////////////////////////////////
// Load V of a 3D element into registers rV[i_DIM][comp][0..Q) for every component.
//   dV         : already offset by elem-stride and dim-stride
//   compstride : distance (in entries of T) between consecutive components
//   rV         : register array rV[DIM_V][NUM_COMP][rV_SIZE]; only slot i_DIM is
//                written. rV_SIZE may differ from Q (e.g. max(P, Q)); only the
//                first Q entries of each component are set.
//   tx         : thread index; thread tx gathers entries tx, Q^2 + tx, ...,
//                (Q-1)*Q^2 + tx of each component; threads with tx >= Q^2 do nothing
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void read_V_3d(const T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  const int slab = Q * Q;  // distance between consecutive entries owned by one thread
  if (tx >= slab) return;

  for (int c = 0; c < NUM_COMP; c++) {
    const T *src = dV + c * compstride + tx;  // entry tx of component c
    for (int k = 0; k < Q; k++) {
      rV[i_DIM][c][k] = src[k * slab];
    }
  }
}
192f80f4a74SSebastian Grimberg
////////////////////////////////////////////////////////////////////////////////
// Store V of a 3D element from registers rV[i_DIM][comp][0..Q) into global
// memory for every component (plain overwrite).
//   dV         : points directly at the element (already offset by elem-stride)
//   compstride : distance (in entries of T) between consecutive components
//   rV         : register array rV[DIM_V][NUM_COMP][rV_SIZE]; only slot i_DIM is
//                read. rV_SIZE may differ from Q (e.g. max(P, Q)).
//   tx         : thread index; thread tx writes entries tx, Q^2 + tx, ...,
//                (Q-1)*Q^2 + tx of each component; threads with tx >= Q^2 do nothing
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void write_V_3d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  const int slab = Q * Q;  // distance between consecutive entries owned by one thread
  if (tx >= slab) return;

  for (int c = 0; c < NUM_COMP; c++) {
    T *dst = dV + c * compstride + tx;  // entry tx of component c
    for (int k = 0; k < Q; k++) {
      dst[k * slab] = rV[i_DIM][c][k];
    }
  }
}
209f80f4a74SSebastian Grimberg
////////////////////////////////////////////////////////////////////////////////
// Accumulate (+=) V of a 3D element from registers rV[i_DIM][comp][0..Q) into
// global memory for every component.
//   dV         : points directly at the element (already offset by elem-stride)
//   compstride : distance (in entries of T) between consecutive components
//   rV         : register array rV[DIM_V][NUM_COMP][rV_SIZE]; only slot i_DIM is
//                read. rV_SIZE may differ from Q (e.g. max(P, Q)).
//   tx         : thread index; thread tx updates entries tx, Q^2 + tx, ...,
//                (Q-1)*Q^2 + tx of each component; threads with tx >= Q^2 do nothing
// The update is a plain (non-atomic) +=.
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void sum_V_3d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  const int slab = Q * Q;  // distance between consecutive entries owned by one thread
  if (tx >= slab) return;

  for (int c = 0; c < NUM_COMP; c++) {
    T *dst = dV + c * compstride + tx;  // entry tx of component c
    for (int k = 0; k < Q; k++) {
      dst[k * slab] += rV[i_DIM][c][k];
    }
  }
}
226db2becc9SJeremy L Thompson
////////////////////////////////////////////////////////////////////////////////
// Copy the B x J matrix T from global memory (dT) into shared memory (sT)
// without transposing. Both use the same layout: entry (b, j) lives at j * B + b.
//   tx : thread index; thread tx copies entries tx, B + tx, ..., (J-1)*B + tx,
//        threads with tx >= B do nothing
// No barrier is issued here -- the caller must sync after the call.
template <int B, int J>
static __device__ __inline__ void read_T_notrans_gm2sm(const int tx, const CeedScalar *dT, CeedScalar *sT) {
  if (tx >= B) return;

  for (int j = 0; j < J; j++) {
    const int idx = j * B + tx;
    sT[idx] = dT[idx];
  }
}
2409e0c01faSSebastian Grimberg
////////////////////////////////////////////////////////////////////////////////
// Copy T from global memory (dT) into shared memory (sT) with a transpose; T is J x B.
// For i in [0, B) and j in [0, J): dT is read at i * J + j (leading dimension J)
// and sT is written at j * B + i (leading dimension B).
//   tx : thread index, playing the role of j; threads with tx >= J do nothing
// No barrier is issued here -- the caller must sync after the call.
template <int B, int J>
static __device__ __inline__ void read_T_trans_gm2sm(const int tx, const CeedScalar *dT, CeedScalar *sT) {
  if (tx >= J) return;

  const CeedScalar *src = dT + tx;      // column tx of dT, stride J
  CeedScalar       *dst = sT + tx * B;  // row tx of sT, contiguous
  for (int i = 0; i < B; i++) {
    dst[i] = src[i * J];
  }
}
254