xref: /libCEED/include/ceed/jit-source/magma/magma-common-tensor.h (revision db2becc9f302fe8eb3a32ace50ce3f3a5d42e6c4)
15aed82e4SJeremy L Thompson // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2f80f4a74SSebastian Grimberg // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3f80f4a74SSebastian Grimberg //
4f80f4a74SSebastian Grimberg // SPDX-License-Identifier: BSD-2-Clause
5f80f4a74SSebastian Grimberg //
6f80f4a74SSebastian Grimberg // This file is part of CEED:  http://github.com/ceed
7f80f4a74SSebastian Grimberg 
83c1e2affSSebastian Grimberg /// @file
93c1e2affSSebastian Grimberg /// Internal header for MAGMA backend common tensor basis definitions
10509d4af6SJeremy L Thompson #pragma once
11f80f4a74SSebastian Grimberg 
123c1e2affSSebastian Grimberg #include "magma-common-defs.h"
13f80f4a74SSebastian Grimberg 
149e0c01faSSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
// Load U or V of a 1D element into the shared-memory buffers sBuffer[comp][] (sU or sV) -- for all components
// devptr is assumed to point directly to the element; component comp starts at devptr + comp * compstride
// only threads with tx < LENGTH take part, each one copying entry tx of every component
// the caller must sync after this call
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void read_1d(const T *devptr, const int compstride, T *sBuffer[NUM_COMP], const int tx) {
  // threads past the end of the element have nothing to copy
  if (tx >= LENGTH) return;

  const T *src = devptr + tx;
  for (int c = 0; c < NUM_COMP; c++) {
    sBuffer[c][tx] = src[c * compstride];
  }
}
26f80f4a74SSebastian Grimberg 
279e0c01faSSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
// Store V of a 1D element from the buffers sBuffer[comp][] (sV) into global memory -- for all components
// devptr is assumed to point directly to the element; component comp starts at devptr + comp * compstride
// only threads with tx < LENGTH take part, each one writing entry tx of every component
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void write_1d(T *sBuffer[NUM_COMP], T *devptr, const int compstride, const int tx) {
  // threads past the end of the element have nothing to write
  if (tx >= LENGTH) return;

  T *dst = devptr + tx;
  for (int c = 0; c < NUM_COMP; c++) {
    dst[c * compstride] = sBuffer[c][tx];
  }
}
38f80f4a74SSebastian Grimberg 
399e0c01faSSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
// Accumulate V of a 1D element from the buffers sBuffer[comp][] (sV) into global memory -- for all components
// devptr is assumed to point directly to the element; component comp starts at devptr + comp * compstride
// only threads with tx < LENGTH take part, each one adding entry tx of every component to the existing value
// the update is a plain (non-atomic) read-modify-write
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void sum_1d(T *sBuffer[NUM_COMP], T *devptr, const int compstride, const int tx) {
  // threads past the end of the element have nothing to add
  if (tx >= LENGTH) return;

  T *dst = devptr + tx;
  for (int c = 0; c < NUM_COMP; c++) {
    dst[c * compstride] += sBuffer[c][tx];
  }
}
50*db2becc9SJeremy L Thompson 
51*db2becc9SJeremy L Thompson ////////////////////////////////////////////////////////////////////////////////
// Load U of a 2D element into registers rU[][][] -- for all components of a single dim
// dU is assumed to be offset by elem-stride and dim-stride
// the register array is assumed to be rU[DIM_U][NUM_COMP][rU_SIZE]
// i_DIM selects which dimension slot of rU is filled
// rU_SIZE can be different from P (e.g. max(P, Q)); only the first P entries are written
// sTmp is a shared memory workspace of size P^2
template <typename T, int P, int DIM_U, int NUM_COMP, int rU_SIZE, int i_DIM>
static __device__ __inline__ void read_U_2d(const T *dU, const int compstride, T rU[DIM_U][NUM_COMP][rU_SIZE], T *sTmp, const int tx) {
  // U is stored as a batch of P vectors, each (1 x P) and contiguous in memory:
  //   vec 0  : [u0, u1, u2, ... u_(P-1)]
  //   vec 1  : [u0, u1, u2, ... u_(P-1)]
  //   ...
  //   vec P-1: [u0, u1, u2, ... u_(P-1)]
  // The threads read vec 0 together, then vec 1, and so on, whereas the kernel
  // wants thread k to own all of vec k in its registers. The data is therefore
  // staged through sTmp and read back transposed.
  const bool active = (tx < P);

  for (int c = 0; c < NUM_COMP; c++) {
    // global memory -> shared memory: thread tx copies entry tx of every vector
    if (active) {
      const T *dU_comp = dU + c * compstride;
      for (int vec = 0; vec < P; vec++) {
        const int idx = vec * P + tx;
        sTmp[idx] = dU_comp[idx];
      }
    }
    // every thread of the block reaches this barrier, active or not
    __syncthreads();

    // shared memory -> registers: thread tx collects the whole of vector tx
    if (active) {
      const T *my_vec = sTmp + tx * P;
      for (int k = 0; k < P; k++) {
        rU[i_DIM][c][k] = my_vec[k];
      }
    }
    // sTmp is reused by the next component, so wait until all reads are done
    __syncthreads();
  }
}
87f80f4a74SSebastian Grimberg 
889e0c01faSSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
// Load V of a 2D element into registers rV[][][] -- for all components of a single dim
// dV is assumed to be offset by elem-stride and dim-stride
// the register array is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM selects which dimension slot of rV is filled
// rV_SIZE can be different from Q (e.g. max(P, Q)); only the first Q entries are written
// thread tx (tx < Q) gathers the entries at offsets tx, Q + tx, 2Q + tx, ... of each component
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void read_V_2d(const T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  // surplus threads do not take part
  if (tx >= Q) return;

  for (int c = 0; c < NUM_COMP; c++) {
    const T *src = dV + c * compstride + tx;
    for (int k = 0; k < Q; k++) {
      rV[i_DIM][c][k] = src[k * Q];
    }
  }
}
104f80f4a74SSebastian Grimberg 
1059e0c01faSSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
// Store V of a 2D element from registers rV[][][] to global memory -- for all components of a single dim
// dV is assumed to be offset by elem-stride and dim-stride
// the register array is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM selects which dimension slot of rV is written out to dV
// rV_SIZE can be different from Q (e.g. max(P, Q)); only the first Q entries are read
// thread tx (tx < Q) scatters its entries to offsets tx, Q + tx, 2Q + tx, ... of each component
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void write_V_2d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  // surplus threads do not take part
  if (tx >= Q) return;

  for (int c = 0; c < NUM_COMP; c++) {
    T *dst = dV + c * compstride + tx;
    for (int k = 0; k < Q; k++) {
      dst[k * Q] = rV[i_DIM][c][k];
    }
  }
}
121f80f4a74SSebastian Grimberg 
1229e0c01faSSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
// Accumulate V of a 2D element from registers rV[][][] into global memory -- for all components of a single dim
// dV is assumed to be offset by elem-stride and dim-stride
// the register array is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM selects which dimension slot of rV is added into dV
// rV_SIZE can be different from Q (e.g. max(P, Q)); only the first Q entries are read
// thread tx (tx < Q) adds its entries at offsets tx, Q + tx, 2Q + tx, ... of each component (non-atomic)
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void sum_V_2d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  // surplus threads do not take part
  if (tx >= Q) return;

  for (int c = 0; c < NUM_COMP; c++) {
    T *dst = dV + c * compstride + tx;
    for (int k = 0; k < Q; k++) {
      dst[k * Q] += rV[i_DIM][c][k];
    }
  }
}
138*db2becc9SJeremy L Thompson 
139*db2becc9SJeremy L Thompson ////////////////////////////////////////////////////////////////////////////////
// Load U of a 3D element into registers rU[][][] -- for all components of a single dim
// dU is assumed to be offset by elem-stride and dim-stride
// the register array is assumed to be rU[DIM_U][NUM_COMP][rU_SIZE]
// i_DIM selects which dimension slot of rU is filled
// rU_SIZE can be different from P (e.g. max(P, Q)); only the first P entries are written
// sTmp is a shared memory workspace of size P^3
template <typename T, int P, int DIM_U, int NUM_COMP, int rU_SIZE, int i_DIM>
static __device__ __inline__ void read_U_3d(const T *dU, const int compstride, T rU[DIM_U][NUM_COMP][rU_SIZE], T *sTmp, const int tx) {
  // U is stored as a batch of P^2 vectors, each (1 x P) and contiguous in memory:
  //   vec 0    : [u0, u1, u2, ... u_(P-1)]
  //   vec 1    : [u0, u1, u2, ... u_(P-1)]
  //   ...
  //   vec P^2-1: [u0, u1, u2, ... u_(P-1)]
  // The threads read the data together in chunks of P^2 consecutive entries,
  // whereas the kernel wants thread k to own all of vec k in its registers.
  // The data is therefore staged through sTmp and read back transposed.
  const bool active = (tx < P * P);

  for (int c = 0; c < NUM_COMP; c++) {
    // global memory -> shared memory: P chunks of P^2 entries, thread tx copies entry tx of each
    if (active) {
      const T *dU_comp = dU + c * compstride;
      for (int chunk = 0; chunk < P; chunk++) {
        const int idx = chunk * P * P + tx;
        sTmp[idx] = dU_comp[idx];
      }
    }
    // every thread of the block reaches this barrier, active or not
    __syncthreads();

    // shared memory -> registers: thread tx collects the whole of vector tx
    if (active) {
      const T *my_vec = sTmp + tx * P;
      for (int k = 0; k < P; k++) {
        rU[i_DIM][c][k] = my_vec[k];
      }
    }
    // sTmp is reused by the next component, so wait until all reads are done
    __syncthreads();
  }
}
175f80f4a74SSebastian Grimberg 
1769e0c01faSSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
// Load V of a 3D element into registers rV[][][] -- for all components of a single dim
// dV is assumed to be offset by elem-stride and dim-stride
// the register array is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM selects which dimension slot of rV is filled
// rV_SIZE can be different from Q (e.g. max(P, Q)); only the first Q entries are written
// thread tx (tx < Q^2) gathers the entries at offsets tx, Q^2 + tx, 2 Q^2 + tx, ... of each component
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void read_V_3d(const T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  const int plane = Q * Q;

  // surplus threads do not take part
  if (tx >= plane) return;

  for (int c = 0; c < NUM_COMP; c++) {
    const T *src = dV + c * compstride + tx;
    for (int k = 0; k < Q; k++) {
      rV[i_DIM][c][k] = src[k * plane];
    }
  }
}
192f80f4a74SSebastian Grimberg 
1939e0c01faSSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
// Store V of a 3D element from registers rV[][][] to global memory -- for all components of a single dim
// dV is assumed to point directly to the element (i.e. already offset by elem-stride)
// the register array is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM selects which dimension slot of rV is written out to dV
// rV_SIZE can be different from Q (e.g. max(P, Q)); only the first Q entries are read
// thread tx (tx < Q^2) scatters its entries to offsets tx, Q^2 + tx, 2 Q^2 + tx, ... of each component
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void write_V_3d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  const int plane = Q * Q;

  // surplus threads do not take part
  if (tx >= plane) return;

  for (int c = 0; c < NUM_COMP; c++) {
    T *dst = dV + c * compstride + tx;
    for (int k = 0; k < Q; k++) {
      dst[k * plane] = rV[i_DIM][c][k];
    }
  }
}
209f80f4a74SSebastian Grimberg 
2109e0c01faSSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
// Accumulate V of a 3D element from registers rV[][][] into global memory -- for all components of a single dim
// dV is assumed to point directly to the element (i.e. already offset by elem-stride)
// the register array is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM selects which dimension slot of rV is added into dV
// rV_SIZE can be different from Q (e.g. max(P, Q)); only the first Q entries are read
// thread tx (tx < Q^2) adds its entries at offsets tx, Q^2 + tx, 2 Q^2 + tx, ... of each component (non-atomic)
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void sum_V_3d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  const int plane = Q * Q;

  // surplus threads do not take part
  if (tx >= plane) return;

  for (int c = 0; c < NUM_COMP; c++) {
    T *dst = dV + c * compstride + tx;
    for (int k = 0; k < Q; k++) {
      dst[k * plane] += rV[i_DIM][c][k];
    }
  }
}
226*db2becc9SJeremy L Thompson 
227*db2becc9SJeremy L Thompson ////////////////////////////////////////////////////////////////////////////////
// Copy T (no-trans) from global memory dT into shared memory sT
// T is B x J, stored with leading dimension B; the layout in sT matches dT
// thread tx (tx < B) copies row tx, i.e. entries tx, B + tx, 2B + tx, ...
// the caller must sync after this call
template <int B, int J>
static __device__ __inline__ void read_T_notrans_gm2sm(const int tx, const CeedScalar *dT, CeedScalar *sT) {
  // surplus threads do not take part
  if (tx >= B) return;

  for (int col = 0; col < J; col++) {
    const int idx = col * B + tx;
    sT[idx] = dT[idx];
  }
  // no barrier here: the caller must sync after this call
}
2409e0c01faSSebastian Grimberg 
2419e0c01faSSebastian Grimberg ////////////////////////////////////////////////////////////////////////////////
// Copy T (trans) from global memory dT into shared memory sT, transposing on the fly
// T is J x B in dT (leading dimension J); sT receives it with leading dimension B
// thread tx (tx < J) reads entries tx, J + tx, 2J + tx, ... of dT and stores them contiguously at sT + tx * B
// the caller must sync after this call
template <int B, int J>
static __device__ __inline__ void read_T_trans_gm2sm(const int tx, const CeedScalar *dT, CeedScalar *sT) {
  // surplus threads do not take part
  if (tx >= J) return;

  const CeedScalar *src = dT + tx;
  CeedScalar       *dst = sT + tx * B;
  for (int k = 0; k < B; k++) {
    dst[k] = src[k * J];
  }
  // no barrier here: the caller must sync after this call
}
254