xref: /libCEED/include/ceed/jit-source/magma/magma-common-tensor.h (revision 5aed82e4fa97acf4ba24a7f10a35f5303a6798e0)
// Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED:  http://github.com/ceed

/// @file
/// Internal header for MAGMA backend common tensor basis definitions
10f80f4a74SSebastian Grimberg #ifndef CEED_MAGMA_COMMON_TENSOR_H
11f80f4a74SSebastian Grimberg #define CEED_MAGMA_COMMON_TENSOR_H
12f80f4a74SSebastian Grimberg 
133c1e2affSSebastian Grimberg #include "magma-common-defs.h"
14f80f4a74SSebastian Grimberg 
159e0c01faSSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
// Load U or V of a 1D element into the per-component shared buffers sBuffer[comp][] -- for all components.
// devptr is assumed to point directly to the element; component `comp` is read starting at devptr[comp * compstride].
// Only threads with tx < LENGTH take part, each copying one entry per component.
// The caller must synchronize after this call before the buffers are consumed.
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void read_1d(const T *devptr, const int compstride, T *sBuffer[NUM_COMP], const int tx) {
  // threads beyond the element length have nothing to copy
  if (tx >= LENGTH) return;

  for (int c = 0; c < NUM_COMP; ++c) {
    const int comp_offset = c * compstride;

    sBuffer[c][tx] = devptr[comp_offset + tx];
  }
}
27f80f4a74SSebastian Grimberg 
289e0c01faSSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
// Store V of a 1D element from the per-component buffers sBuffer[comp][] to global memory -- for all components.
// devptr is assumed to point directly to the element; component `comp` is written starting at devptr[comp * compstride].
// Only threads with tx < LENGTH take part, each copying one entry per component.
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void write_1d(T *sBuffer[NUM_COMP], T *devptr, const int compstride, const int tx) {
  // threads beyond the element length have nothing to copy
  if (tx >= LENGTH) return;

  for (int c = 0; c < NUM_COMP; ++c) {
    const int comp_offset = c * compstride;

    devptr[comp_offset + tx] = sBuffer[c][tx];
  }
}
39f80f4a74SSebastian Grimberg 
409e0c01faSSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
// Load U of a 2D element into registers rU[][][] -- for all components of a single dim.
// dU is assumed to be offset by elem-stride and dim-stride.
// The register array is assumed to be rU[DIM_U][NUM_COMP][rU_SIZE]; i_DIM selects the dim slot of rU that is filled.
// rU_SIZE can be different from P (e.g. max(P, Q)); only the first P entries of each component are written.
// sTmp is a shared memory workspace of size P^2.
// The function contains block-wide barriers, so every thread of the block must call it (not only those with tx < P).
template <typename T, int P, int DIM_U, int NUM_COMP, int rU_SIZE, int i_DIM>
static __device__ __inline__ void read_U_2d(const T *dU, const int compstride, T rU[DIM_U][NUM_COMP][rU_SIZE], T *sTmp, const int tx) {
  // U for one component is a batch of P vectors of length P, each contiguous in memory:
  //   vec 0  : [u0, u1, u2, ... u_(P-1)]
  //   vec 1  : [u0, u1, u2, ... u_(P-1)]
  //   ...
  //   vec P-1: [u0, u1, u2, ... u_(P-1)]
  // The threads fetch the data collaboratively (thread tx takes entry i * P + tx for every i),
  // but the kernel wants thread 0 to hold all of vec 0 in registers, thread 1 all of vec 1, and so on.
  // Staging through sTmp performs that transpose.
  const bool is_active = tx < P;

  for (int c = 0; c < NUM_COMP; ++c) {
    const int comp_offset = c * compstride;

    // stage this component: global memory -> shared workspace
    if (is_active) {
      for (int row = 0; row < P; ++row) {
        const int idx = row * P + tx;

        sTmp[idx] = dU[comp_offset + idx];
      }
    }
    // barrier stays outside the guard so that all threads reach it
    __syncthreads();

    // thread tx picks up its own vector (entries tx * P ... tx * P + P - 1)
    if (is_active) {
      const int vec_start = tx * P;

      for (int k = 0; k < P; ++k) {
        rU[i_DIM][c][k] = sTmp[vec_start + k];
      }
    }
    // workspace is reused by the next component, so wait until everyone has read it
    __syncthreads();
  }
}
76f80f4a74SSebastian Grimberg 
779e0c01faSSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
// Load V of a 2D element into registers rV[][][] -- for all components of a single dim.
// dV is assumed to be offset by elem-stride and dim-stride.
// The register array is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]; i_DIM selects the dim slot of rV that is filled.
// rV_SIZE can be different from Q (e.g. max(P, Q)); only the first Q entries of each component are written.
// Thread tx (tx < Q) gathers entries tx, tx + Q, tx + 2Q, ... of every component.
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void read_V_2d(const T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  // threads outside the Q-wide range hold no part of V
  if (tx >= Q) return;

  for (int c = 0; c < NUM_COMP; ++c) {
    const int first = c * compstride + tx;

    for (int k = 0; k < Q; ++k) {
      rV[i_DIM][c][k] = dV[first + k * Q];
    }
  }
}
93f80f4a74SSebastian Grimberg 
949e0c01faSSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
// Store V of a 2D element from registers rV[][][] to global memory -- for all components of a single dim.
// dV is assumed to be offset by elem-stride and dim-stride.
// The register array is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]; i_DIM selects the dim slot of rV that is written out.
// rV_SIZE can be different from Q (e.g. max(P, Q)); only the first Q entries of each component are stored.
// Thread tx (tx < Q) scatters its values to entries tx, tx + Q, tx + 2Q, ... of every component.
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void write_V_2d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  // threads outside the Q-wide range hold no part of V
  if (tx >= Q) return;

  for (int c = 0; c < NUM_COMP; ++c) {
    const int first = c * compstride + tx;

    for (int k = 0; k < Q; ++k) {
      dV[first + k * Q] = rV[i_DIM][c][k];
    }
  }
}
110f80f4a74SSebastian Grimberg 
1119e0c01faSSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
// Load U of a 3D element into registers rU[][][] -- for all components of a single dim.
// dU is assumed to be offset by elem-stride and dim-stride.
// The register array is assumed to be rU[DIM_U][NUM_COMP][rU_SIZE]; i_DIM selects the dim slot of rU that is filled.
// rU_SIZE can be different from P (e.g. max(P, Q)); only the first P entries of each component are written.
// sTmp is a shared memory workspace of size P^3.
// The function contains block-wide barriers, so every thread of the block must call it (not only those with tx < P^2).
template <typename T, int P, int DIM_U, int NUM_COMP, int rU_SIZE, int i_DIM>
static __device__ __inline__ void read_U_3d(const T *dU, const int compstride, T rU[DIM_U][NUM_COMP][rU_SIZE], T *sTmp, const int tx) {
  // U for one component is a batch of P^2 vectors of length P, each contiguous in memory:
  //   vec 0    : [u0, u1, u2, ... u_(P-1)]
  //   vec 1    : [u0, u1, u2, ... u_(P-1)]
  //   ...
  //   vec P^2-1: [u0, u1, u2, ... u_(P-1)]
  // The threads fetch the data collaboratively (thread tx takes entry i * P^2 + tx for every i),
  // but the kernel wants thread 0 to hold all of vec 0 in registers, thread 1 all of vec 1, and so on.
  // Staging through sTmp performs that transpose.
  const int  slab      = P * P;
  const bool is_active = tx < slab;

  for (int c = 0; c < NUM_COMP; ++c) {
    const int comp_offset = c * compstride;

    // stage this component: global memory -> shared workspace
    if (is_active) {
      for (int s = 0; s < P; ++s) {
        const int idx = s * slab + tx;

        sTmp[idx] = dU[comp_offset + idx];
      }
    }
    // barrier stays outside the guard so that all threads reach it
    __syncthreads();

    // thread tx picks up its own vector (entries tx * P ... tx * P + P - 1)
    if (is_active) {
      const int vec_start = tx * P;

      for (int k = 0; k < P; ++k) {
        rU[i_DIM][c][k] = sTmp[vec_start + k];
      }
    }
    // workspace is reused by the next component, so wait until everyone has read it
    __syncthreads();
  }
}
147f80f4a74SSebastian Grimberg 
1489e0c01faSSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
// Load V of a 3D element into registers rV[][][] -- for all components of a single dim.
// dV is assumed to be offset by elem-stride and dim-stride.
// The register array is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]; i_DIM selects the dim slot of rV that is filled.
// rV_SIZE can be different from Q (e.g. max(P, Q)); only the first Q entries of each component are written.
// Thread tx (tx < Q^2) gathers entries tx, tx + Q^2, tx + 2 Q^2, ... of every component.
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void read_V_3d(const T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  // threads outside the Q^2-wide range hold no part of V
  if (tx >= Q * Q) return;

  for (int c = 0; c < NUM_COMP; ++c) {
    const int first = c * compstride + tx;

    for (int k = 0; k < Q; ++k) {
      rV[i_DIM][c][k] = dV[first + k * (Q * Q)];
    }
  }
}
164f80f4a74SSebastian Grimberg 
1659e0c01faSSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
// Store V of a 3D element from registers rV[][][] to global memory -- for all components of a single dim.
// dV is assumed to point directly to the element (i.e. already offset by elem-stride).
// The register array is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]; i_DIM selects the dim slot of rV that is written out.
// rV_SIZE can be different from Q (e.g. max(P, Q)); only the first Q entries of each component are stored.
// Thread tx (tx < Q^2) scatters its values to entries tx, tx + Q^2, tx + 2 Q^2, ... of every component.
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void write_V_3d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  // threads outside the Q^2-wide range hold no part of V
  if (tx >= (Q * Q)) return;

  for (int c = 0; c < NUM_COMP; ++c) {
    const int first = c * compstride + tx;

    for (int k = 0; k < Q; ++k) {
      dV[first + k * (Q * Q)] = rV[i_DIM][c][k];
    }
  }
}
181f80f4a74SSebastian Grimberg 
1829e0c01faSSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
// Copy T (no-trans) from dT into shared memory sT.
// T is B x J; the B * J entries keep the same linear layout in sT as in dT.
// Thread tx (tx < B) copies entries tx, tx + B, tx + 2B, ... (J of them).
// The caller must synchronize after this call.
template <int B, int J>
static __device__ __inline__ void read_T_notrans_gm2sm(const int tx, const CeedScalar *dT, CeedScalar *sT) {
  // threads beyond the leading dimension do not participate
  if (tx >= B) return;

  for (int col = 0; col < J; ++col) {
    const int idx = col * B + tx;

    sT[idx] = dT[idx];
  }
  // no barrier here: must sync after call
}
1959e0c01faSSebastian Grimberg 
1969e0c01faSSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
// Copy T (trans) from dT into shared memory sT, transposing on the way.
// T is J x B; entry dT[i * J + tx] is stored at sT[tx * B + i].
// Thread tx (tx < J) reads entries tx, tx + J, tx + 2J, ... of dT (B of them) and
// writes them to the contiguous range sT[tx * B] ... sT[tx * B + B - 1].
// The caller must synchronize after this call.
template <int B, int J>
static __device__ __inline__ void read_T_trans_gm2sm(const int tx, const CeedScalar *dT, CeedScalar *sT) {
  // threads beyond the leading dimension of dT do not participate
  if (tx >= J) return;

  const int dst_start = tx * B;

  for (int k = 0; k < B; ++k) {
    sT[dst_start + k] = dT[k * J + tx];
  }
  // no barrier here: must sync after call
}
209f80f4a74SSebastian Grimberg 
210f80f4a74SSebastian Grimberg #endif  // CEED_MAGMA_COMMON_TENSOR_H
211