xref: /libCEED/include/ceed/jit-source/magma/magma-common-tensor.h (revision 5aed82e4fa97acf4ba24a7f10a35f5303a6798e0)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 /// @file
9 /// Internal header for MAGMA backend common tensor basis definitions
10 #ifndef CEED_MAGMA_COMMON_TENSOR_H
11 #define CEED_MAGMA_COMMON_TENSOR_H
12 
13 #include "magma-common-defs.h"
14 
////////////////////////////////////////////////////////////////////////////////
// read U or V of a 1D element into shared memory sU[][] or sV[][] --  for all components
// the devptr is assumed to point directly to the element
// must sync after call
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void read_1d(const T *devptr, const int compstride, T *sBuffer[NUM_COMP], const int tx) {
  // threads past LENGTH have nothing to load; there is no barrier inside this
  // helper, so an early exit is safe (the caller must sync afterwards)
  if (tx >= LENGTH) return;
  for (int icomp = 0; icomp < NUM_COMP; icomp++) {
    sBuffer[icomp][tx] = devptr[icomp * compstride + tx];
  }
}
27 
////////////////////////////////////////////////////////////////////////////////
// write V of a 1D element from shared memory sV[][] into global memory --  for all components
// the devptr is assumed to point directly to the element
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void write_1d(T *sBuffer[NUM_COMP], T *devptr, const int compstride, const int tx) {
  // no barrier inside, so threads past LENGTH may simply exit
  if (tx >= LENGTH) return;
  for (int icomp = 0; icomp < NUM_COMP; icomp++) {
    devptr[icomp * compstride + tx] = sBuffer[icomp][tx];
  }
}
39 
////////////////////////////////////////////////////////////////////////////////
// read U of a 2D element into registers rU[][][] --  for all components of a single dim
// dU is assumed to be offset by elem-stride and dim-stride
// register array is assumed to be rU[DIM_U][NUM_COMP][rU_SIZE]
// i_DIM selects which dim slot of rU receives the data
// rU_SIZE can be different from P (e.g. max(P, Q))
// sTmp is a shared memory workspace of size P^2
template <typename T, int P, int DIM_U, int NUM_COMP, int rU_SIZE, int i_DIM>
static __device__ __inline__ void read_U_2d(const T *dU, const int compstride, T rU[DIM_U][NUM_COMP][rU_SIZE], T *sTmp, const int tx) {
  // U is a batch of P row-vectors of length P, each contiguous in memory.
  // Threads cooperatively stage the whole P x P tile into shared memory with
  // coalesced loads, then each thread tx gathers its own contiguous vector
  // (row tx of the tile) into registers -- i.e. a transpose through sTmp.
  const bool active = (tx < P);
  for (int icomp = 0; icomp < NUM_COMP; icomp++) {
    // stage: global -> shared, same layout as global memory
    if (active) {
      for (int j = 0; j < P; j++) {
        sTmp[j * P + tx] = dU[icomp * compstride + j * P + tx];
      }
    }
    // barriers are outside the divergent guard so every thread reaches them
    __syncthreads();

    // gather: thread tx owns the contiguous run sTmp[tx*P .. tx*P+P-1]
    if (active) {
      for (int j = 0; j < P; j++) {
        rU[i_DIM][icomp][j] = sTmp[tx * P + j];
      }
    }
    // sTmp is reused for the next component; keep all threads in step
    __syncthreads();
  }
}
76 
////////////////////////////////////////////////////////////////////////////////
// read V of a 2D element into registers rV[][][] --  for all components of a single dim
// dV is assumed to be offset by elem-stride and dim-stride
// register array is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM selects which dim slot of rV receives the data
// rV_SIZE can be different from Q (e.g. max(P, Q))
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void read_V_2d(const T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  // strided gather straight into registers; no shared memory and no barrier,
  // so inactive threads may exit immediately
  if (tx >= Q) return;
  for (int icomp = 0; icomp < NUM_COMP; icomp++) {
    for (int jq = 0; jq < Q; jq++) {
      rV[i_DIM][icomp][jq] = dV[icomp * compstride + jq * Q + tx];
    }
  }
}
93 
////////////////////////////////////////////////////////////////////////////////
// write V of a 2D element from registers rV[][][] to global memory --  for all components of a single dim
// dV is assumed to be offset by elem-stride and dim-stride
// register array is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM selects which dim slot of rV is written out
// rV_SIZE can be different from Q (e.g. max(P, Q))
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void write_V_2d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  // exact inverse of read_V_2d: strided scatter from registers; no barrier,
  // so inactive threads may exit immediately
  if (tx >= Q) return;
  for (int icomp = 0; icomp < NUM_COMP; icomp++) {
    for (int jq = 0; jq < Q; jq++) {
      dV[icomp * compstride + jq * Q + tx] = rV[i_DIM][icomp][jq];
    }
  }
}
110 
////////////////////////////////////////////////////////////////////////////////
// read U of a 3D element into registers rU[][][] --  for all components of a single dim
// dU is assumed to be offset by elem-stride and dim-stride
// register array is assumed to be rU[DIM_U][NUM_COMP][rU_SIZE]
// i_DIM selects which dim slot of rU receives the data
// rU_SIZE can be different from P (e.g. max(P, Q))
// sTmp is a shared memory workspace of size P^3
template <typename T, int P, int DIM_U, int NUM_COMP, int rU_SIZE, int i_DIM>
static __device__ __inline__ void read_U_3d(const T *dU, const int compstride, T rU[DIM_U][NUM_COMP][rU_SIZE], T *sTmp, const int tx) {
  // U is a batch of P^2 vectors of length P, each contiguous in memory.
  // Threads cooperatively stage the whole P^3 brick into shared memory with
  // coalesced loads, then each thread tx gathers its own contiguous vector
  // into registers -- i.e. a transpose through sTmp.
  const bool active = (tx < P * P);
  for (int icomp = 0; icomp < NUM_COMP; icomp++) {
    // stage: global -> shared, same layout as global memory
    if (active) {
      for (int j = 0; j < P; j++) {
        sTmp[j * P * P + tx] = dU[icomp * compstride + j * P * P + tx];
      }
    }
    // barriers are outside the divergent guard so every thread reaches them
    __syncthreads();

    // gather: thread tx owns the contiguous run sTmp[tx*P .. tx*P+P-1]
    if (active) {
      for (int j = 0; j < P; j++) {
        rU[i_DIM][icomp][j] = sTmp[tx * P + j];
      }
    }
    // sTmp is reused for the next component; keep all threads in step
    __syncthreads();
  }
}
147 
////////////////////////////////////////////////////////////////////////////////
// read V of a 3D element into registers rV[][][] --  for all components of a single dim
// dV is assumed to be offset by elem-stride and dim-stride
// register array is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM selects which dim slot of rV receives the data
// rV_SIZE can be different from Q (e.g. max(P, Q))
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void read_V_3d(const T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  // strided gather straight into registers; no shared memory and no barrier,
  // so inactive threads may exit immediately
  if (tx >= Q * Q) return;
  for (int icomp = 0; icomp < NUM_COMP; icomp++) {
    for (int jq = 0; jq < Q; jq++) {
      rV[i_DIM][icomp][jq] = dV[icomp * compstride + jq * (Q * Q) + tx];
    }
  }
}
164 
////////////////////////////////////////////////////////////////////////////////
// write V of a 3D element from registers rV[][][] to global memory --  for all components of a single dim
// dV is assumed to point directly to the element (i.e. already offset by elem-stride)
// register array is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM selects which dim slot of rV is written out
// rV_SIZE can be different from Q (e.g. max(P, Q))
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void write_V_3d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  // exact inverse of read_V_3d: strided scatter from registers; no barrier,
  // so inactive threads may exit immediately
  if (tx >= Q * Q) return;
  for (int icomp = 0; icomp < NUM_COMP; icomp++) {
    for (int jq = 0; jq < Q; jq++) {
      dV[icomp * compstride + jq * (Q * Q) + tx] = rV[i_DIM][icomp][jq];
    }
  }
}
181 
////////////////////////////////////////////////////////////////////////////////
// reads the B x J basis matrix T (no-trans) from global into shared memory,
// preserving its column-major layout
// must sync after call (no barrier inside)
template <int B, int J>
static __device__ __inline__ void read_T_notrans_gm2sm(const int tx, const CeedScalar *dT, CeedScalar *sT) {
  if (tx >= B) return;
  // thread tx copies row tx of every column: straight memcpy-style layout copy
  for (int j = 0; j < J; j++) {
    sT[j * B + tx] = dT[j * B + tx];
  }
}
195 
////////////////////////////////////////////////////////////////////////////////
// reads the J x B basis matrix T (trans) from global into shared memory,
// transposing on the fly so sT ends up B x J like the no-trans case
// must sync after call (no barrier inside)
template <int B, int J>
static __device__ __inline__ void read_T_trans_gm2sm(const int tx, const CeedScalar *dT, CeedScalar *sT) {
  if (tx >= J) return;
  // thread tx reads column tx of dT (stride J) and writes it as a contiguous
  // run of sT, so sT[tx * B + i] == dT[i * J + tx]
  for (int i = 0; i < B; i++) {
    sT[tx * B + i] = dT[i * J + tx];
  }
}
209 
210 #endif  // CEED_MAGMA_COMMON_TENSOR_H
211