xref: /libCEED/include/ceed/jit-source/magma/magma-common-tensor.h (revision f80f4a748154eed4bc661c135f695b92b1bc45b9)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #ifndef CEED_MAGMA_COMMON_TENSOR_H
9 #define CEED_MAGMA_COMMON_TENSOR_H
10 
#define MAGMA_MAXTHREADS_1D 128
#define MAGMA_MAXTHREADS_2D 128
#define MAGMA_MAXTHREADS_3D 64
// Define macro for determining number of threads in y-direction
// for basis kernels
#define MAGMA_BASIS_NTCOL(x, maxt) (((maxt) < (x)) ? 1 : ((maxt) / (x)))
// Define macro for computing the total threads in a block
// for use with __launch_bounds__()
// (x) is parenthesized so expression arguments (e.g. P*Q, a+b) bind correctly
#define MAGMA_BASIS_BOUNDS(x, maxt) ((x) * MAGMA_BASIS_NTCOL(x, maxt))
20 
21 //////////////////////////////////////////////////////////////////////////////////////////
22 // read U or V of a 1D element into shared memory sU[][] or sV[][] --  for all components
23 // the devptr is assumed to point directly to the element
24 // must sync after call
template <typename T, int LENGTH, int NCOMP_>
__device__ __inline__ void read_1d(const T *devptr, const int compstride, T *sBuffer[NCOMP_], const int tx) {
  // Each of the first LENGTH threads copies its entry of every component
  // from global memory into the per-component shared buffers.
  if (tx >= LENGTH) return;
  for (int comp = 0; comp < NCOMP_; comp++) {
    sBuffer[comp][tx] = devptr[compstride * comp + tx];
  }
}
33 
34 //////////////////////////////////////////////////////////////////////////////////////////
35 // write V of a 1D element into global memory from sV[][] --  for all components
36 // the devptr is assumed to point directly to the element
template <typename T, int LENGTH, int NCOMP_>
__device__ __inline__ void write_1d(T *sBuffer[NCOMP_], T *devptr, const int compstride, const int tx) {
  // Each of the first LENGTH threads writes its entry of every component
  // from the per-component shared buffers back to global memory.
  if (tx >= LENGTH) return;
  for (int comp = 0; comp < NCOMP_; comp++) {
    devptr[compstride * comp + tx] = sBuffer[comp][tx];
  }
}
45 
46 //////////////////////////////////////////////////////////////////////////////////////////
47 // read U of a 2D element into registers rU[][][] --  for all components of a single dim
48 // dU is assumed to be offset by elem-stride and dim-stride
49 // register is assumed to be rU[DIMU][NCOMP_][rUsize]
50 // iDIM specifies which dimension is being read into in rU
51 // rUsize can be different from P_ (e.g. MAXP_Q)
52 // sTmp is a shared memory workspace of size P_^2
template <typename T, int P_, int DIMU, int NCOMP_, int rUsize, int iDIM>
__device__ __inline__ void readU_2d(const T *dU, const int compstride, T rU[DIMU][NCOMP_][rUsize], T *sTmp, const int tx) {
  // read U as a batch P_ of (1xP_) vectors
  // vec 0  : [u0, u1, u2, ... u_(P_-1)] -- contiguous in memory
  // vec 1  : [u0, u1, u2, ... u_(P_-1)] -- contiguous in memory
  // ...
  // vec P_-1: [u0, u1, u2, ... u_(P_-1)] -- contiguous in memory
  // threads collaboratively read vec0 and then vec1 and so on
  // but for the kernel, we want
  // thread 0 to hold all of vec0 in registers, and
  // thread 1 to hold all of vec1 in registers, and so on
  // so we need to transpose
  for (int icomp = 0; icomp < NCOMP_; icomp++) {
    // read from global memory into shared memory
    // coalesced stage: thread tx reads entry tx of vec i, for every i
    if (tx < P_) {
      for (int i = 0; i < P_; i++) {
        sTmp[i * P_ + tx] = dU[icomp * compstride + i * P_ + tx];
      }
    }
    // barrier: the staged component must be fully written before any thread
    // performs the transposed read below
    __syncthreads();

    // transposed read: thread tx gathers all P_ entries of vec tx into registers
    if (tx < P_) {
      for (int i = 0; i < P_; i++) {
        rU[iDIM][icomp][i] = sTmp[tx * P_ + i];
      }
    }
    // barrier: sTmp is reused by the next component's stage; do not let any
    // thread overwrite it while others are still reading
    __syncthreads();
  }
}
82 
83 //////////////////////////////////////////////////////////////////////////////////////////
84 // read V of a 2D element into registers rV[][][] --  for all components of a single dim
85 // dV is assumed to be offset by elem-stride and dim-stride
86 // register is assumed to be rV[DIMV][NCOMP_][rVsize]
87 // iDIM specifies which dimension is being read into in rV
88 // rVsize can be different from P_ (e.g. MAXP_Q)
template <typename T, int Q_, int DIMV, int NCOMP_, int rVsize, int iDIM>
__device__ __inline__ void readV_2d(const T *dV, const int compstride, T rV[DIMV][NCOMP_][rVsize], const int tx) {
  // Thread tx loads its Q_ entries of every component straight from global
  // memory into registers (entry j lives at offset j*Q_ + tx).
  if (tx >= Q_) return;
  for (int comp = 0; comp < NCOMP_; comp++) {
    const T *src = dV + comp * compstride + tx;
    for (int j = 0; j < Q_; j++) {
      rV[iDIM][comp][j] = src[j * Q_];
    }
  }
}
99 
100 //////////////////////////////////////////////////////////////////////////////////////////
101 // write V of a 2D element from registers rV[][][] to global memory --  for all components of a single dim
102 // dV is assumed to be offset by elem-stride and dim-stride
103 // register is assumed to be rV[DIMV][NCOMP_][rVsize]
104 // iDIM specifies which dimension is being read from in rV
105 // idim specifies which dimension is being written to in dV
106 // rVsize can be different from P_ (e.g. MAXP_Q)
template <typename T, int Q_, int DIMV, int NCOMP_, int rVsize, int iDIM>
__device__ __inline__ void writeV_2d(T *dV, const int compstride, T rV[DIMV][NCOMP_][rVsize], const int tx) {
  // Thread tx stores its Q_ register entries of every component back to
  // global memory (entry j goes to offset j*Q_ + tx).
  if (tx >= Q_) return;
  for (int comp = 0; comp < NCOMP_; comp++) {
    T *dst = dV + comp * compstride + tx;
    for (int j = 0; j < Q_; j++) {
      dst[j * Q_] = rV[iDIM][comp][j];
    }
  }
}
117 
118 //////////////////////////////////////////////////////////////////////////////////////////
119 // read U of a 3D element into registers rU[][][] --  for all components of a single dim
120 // dU is assumed to be offset by elem-stride and dim-stride
121 // register is assumed to be rU[DIMU][NCOMP_][rUsize]
122 // iDIM specifies which dimension is being read into in rU
123 // rUsize can be different from P_ (e.g. MAXP_Q)
124 // sTmp is a shared memory workspace of size P_^3
template <typename T, int P_, int DIMU, int NCOMP_, int rUsize, int iDIM>
__device__ __inline__ void readU_3d(const T *dU, const int compstride, T rU[DIMU][NCOMP_][rUsize], T *sTmp, const int tx) {
  // read U as a batch P_^2 of (1xP_) vectors
  // vec 0    : [u0, u1, u2, ... u_(P_-1)] -- contiguous in memory
  // vec 1    : [u0, u1, u2, ... u_(P_-1)] -- contiguous in memory
  // ...
  // vec P_^2-1: [u0, u1, u2, ... u_(P_-1)] -- contiguous in memory
  // threads collaboratively read vec0 and then vec1 and so on
  // but for the kernel, we want
  // thread 0 to hold all of vec0 in registers, and
  // thread 1 to hold all of vec1 in registers, and so on
  // so we need to transpose
  for (int icomp = 0; icomp < NCOMP_; icomp++) {
    // read from global memory into shared memory
    // coalesced stage: thread tx copies entry tx of each P_^2-wide chunk, so
    // sTmp ends up holding the component's P_^3 values in global-memory order
    if (tx < P_ * P_) {
      for (int i = 0; i < P_; i++) {
        sTmp[i * P_ * P_ + tx] = dU[icomp * compstride + i * P_ * P_ + tx];
      }
    }
    // barrier: the staged component must be fully written before the
    // transposed read below
    __syncthreads();

    // transposed read: thread tx gathers the P_ contiguous entries of vec tx
    // (starting at sTmp[tx*P_]) into registers
    if (tx < P_ * P_) {
      for (int i = 0; i < P_; i++) {
        rU[iDIM][icomp][i] = sTmp[tx * P_ + i];
      }
    }
    // barrier: sTmp is reused by the next component's stage; do not let any
    // thread overwrite it while others are still reading
    __syncthreads();
  }
}
154 
155 //////////////////////////////////////////////////////////////////////////////////////////
156 // read V of a 3D element into registers rV[][][] --  for all components of a single dim
157 // dV is assumed to be offset by elem-stride and dim-stride
158 // register is assumed to be rV[DIMV][NCOMP_][rVsize]
159 // iDIM specifies which dimension is being read into in rV
160 // rVsize can be different from P_ (e.g. MAXP_Q)
template <typename T, int Q_, int DIMV, int NCOMP_, int rVsize, int iDIM>
__device__ __inline__ void readV_3d(const T *dV, const int compstride, T rV[DIMV][NCOMP_][rVsize], const int tx) {
  // Thread tx loads its Q_ entries of every component straight from global
  // memory into registers (entry j lives at offset j*Q_^2 + tx).
  if (tx >= Q_ * Q_) return;
  for (int comp = 0; comp < NCOMP_; comp++) {
    const T *src = dV + comp * compstride + tx;
    for (int j = 0; j < Q_; j++) {
      rV[iDIM][comp][j] = src[j * (Q_ * Q_)];
    }
  }
}
171 
172 //////////////////////////////////////////////////////////////////////////////////////////
173 // write V of a 3D element from registers rV[][][] to global memory --  for all components of a single dim
// dV is assumed to be offset by elem-stride and dim-stride
175 // register is assumed to be rV[DIMV][NCOMP_][rVsize]
176 // iDIM specifies which dimension is being read from in rV
177 // idim specifies which dimension is being written to in dV
178 // rVsize can be different from P_ (e.g. MAXP_Q)
template <typename T, int Q_, int DIMV, int NCOMP_, int rVsize, int iDIM>
__device__ __inline__ void writeV_3d(T *dV, const int compstride, T rV[DIMV][NCOMP_][rVsize], const int tx) {
  // Thread tx stores its Q_ register entries of every component back to
  // global memory (entry j goes to offset j*Q_^2 + tx).
  if (tx >= Q_ * Q_) return;
  for (int comp = 0; comp < NCOMP_; comp++) {
    T *dst = dV + comp * compstride + tx;
    for (int j = 0; j < Q_; j++) {
      dst[j * (Q_ * Q_)] = rV[iDIM][comp][j];
    }
  }
}
189 
190 //////////////////////////////////////////////////////////////////////////////////////////
191 // reads T into shared memory
192 // must sync after call
template <int B, int J>
__device__ __inline__ void dread_T_gm2sm(const int tx, const magma_trans_t transT, const CeedScalar *dT, CeedScalar *sT) {
  // Stage the basis matrix T into shared memory. For MagmaNoTrans the B x J
  // matrix is copied as-is; otherwise the J x B input is transposed on the
  // fly so that sT has the same B x J layout in both cases.
  if (transT == MagmaNoTrans) {
    // T is B x J
    if (tx < B) {
      for (int j = 0; j < J; j++) {
        sT[j * B + tx] = dT[j * B + tx];
      }
    }
  } else {
    // T is J x B
    if (tx < J) {
      for (int b = 0; b < B; b++) {
        sT[tx * B + b] = dT[b * J + tx];
      }
    }
  }
  // caller must __syncthreads() before reading sT
}
212 
213 //////////////////////////////////////////////////////////////////////////////////////////
214 // reads a slice of U from shared/global memory into registers
215 // the correct pointer U must be precomputed
template <int B>
__device__ __inline__ void dread_U_gsm2reg(const int C, const int tx_, const CeedScalar *U, CeedScalar rU[B]) {
  // Load one length-B slice of U into registers: entry b is read at
  // offset b*C + tx_ from the precomputed pointer U.
  for (int b = 0; b < B; b++) {
    rU[b] = U[b * C + tx_];
  }
}
222 
223 //////////////////////////////////////////////////////////////////////////////////////////
// reads a slice of V from shared/global memory into registers
225 // the correct pointer V must be precomputed
template <int J>
__device__ __inline__ void dread_V_gsm2reg(const int C, const int tx_, const CeedScalar *V, CeedScalar rV[J]) {
  // Load one length-J slice of V into registers: entry j is read at
  // offset j*C + tx_ from the precomputed pointer V.
  for (int j = 0; j < J; j++) {
    rV[j] = V[j * C + tx_];
  }
}
232 
233 //////////////////////////////////////////////////////////////////////////////////////////
234 // writes a slice of V from reg to shared/global memory
235 // the correct pointer V must be precomputed
template <int J>
__device__ __inline__ void dwrite_V_reg2gsm(const int C, const int tx_, CeedScalar rV[J], CeedScalar *V) {
  // Store a length-J register slice back to memory: entry j is written at
  // offset j*C + tx_ from the precomputed pointer V.
  for (int j = 0; j < J; j++) {
    V[j * C + tx_] = rV[j];
  }
}
242 
243 //////////////////////////////////////////////////////////////////////////////////////////
244 // multiply a slice of U times T to produce a slice of V
template <int B, int J>
__device__ __inline__ void dgemm_slice(CeedScalar alpha, CeedScalar *sT, CeedScalar rU[B], CeedScalar beta, CeedScalar rV[J]) {
  // Per-thread GEMM slice: rV[j] = beta * rV[j] + alpha * sum_b rU[b] * sT[j*B + b].
  for (int j = 0; j < J; j++) {
    const CeedScalar *sTcol = sT + j * B;
    CeedScalar        dot   = 0.0;
    for (int b = 0; b < B; b++) {
      dot += rU[b] * sTcol[b];
    }
    // same FP operation order as scaling rV then accumulating
    rV[j] = rV[j] * beta + alpha * dot;
  }
}
257 
258 //////////////////////////////////////////////////////////////////////////////////////////
template <int B, int J>
__device__ __inline__ void dgemm_ceed_device(const int tx, const int A, const int C, magma_trans_t transT, CeedScalar *sT, const CeedScalar alpha,
                                             const CeedScalar beta, const CeedScalar *dU, CeedScalar *dV, CeedScalar rU[B], CeedScalar rV[J]) {
  // Each group of C consecutive threads (a "slice") processes one C x B
  // panel of U and its matching C x J panel of V, using the staged sT.
  // Note: A and transT are not referenced in this routine.
  const int lane  = tx % C;  // position within the slice
  const int slice = tx / C;  // which slice this thread belongs to

  // advance pointers for U and V to this slice's panels
  const CeedScalar *sliceU = dU + slice * C * B;
  CeedScalar       *sliceV = dV + slice * C * J;

  // V only contributes to the output when beta is non-zero, so skip the
  // read otherwise (rV may be uninitialized in that case, which is fine
  // because dgemm_slice multiplies it by beta first)
  if (beta != 0.0) {
    dread_V_gsm2reg<J>(C, lane, (const CeedScalar *)sliceV, rV);
  }

  // read U into registers
  dread_U_gsm2reg<B>(C, lane, sliceU, rU);

  // multiply: rV = alpha * (rU x sT-slice) + beta * rV
  dgemm_slice<B, J>(alpha, sT, rU, beta, rV);

  // write V back
  dwrite_V_reg2gsm<J>(C, lane, rV, sliceV);
}
283 
284 #endif  // CEED_MAGMA_COMMON_TENSOR_H
285