xref: /libCEED/include/ceed/jit-source/magma/magma-common-tensor.h (revision f8a0df597ca176fee6b07766b6124704acaa0050)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 /// @file
9 /// Internal header for MAGMA backend common tensor basis definitions
10 #pragma once
11 
12 #include "magma-common-defs.h"
13 
////////////////////////////////////////////////////////////////////////////////
// load U or V of a 1D element from global memory into shared memory sU[][] or sV[][], for every component
// devptr must already point to the beginning of the element
// a block-level synchronization is required after this call
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void read_1d(const T *devptr, const int compstride, T *sBuffer[NUM_COMP], const int tx) {
  if (tx >= LENGTH) return;  // extra threads have nothing to load
  for (int icomp = 0; icomp < NUM_COMP; icomp++) {
    sBuffer[icomp][tx] = devptr[icomp * compstride + tx];
  }
}
26 
////////////////////////////////////////////////////////////////////////////////
// store V of a 1D element from shared memory sV[][] back to global memory, for every component
// devptr must already point to the beginning of the element
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void write_1d(T *sBuffer[NUM_COMP], T *devptr, const int compstride, const int tx) {
  if (tx >= LENGTH) return;  // extra threads have nothing to store
  for (int icomp = 0; icomp < NUM_COMP; icomp++) {
    devptr[icomp * compstride + tx] = sBuffer[icomp][tx];
  }
}
38 
////////////////////////////////////////////////////////////////////////////////
// accumulate (+=) V of a 1D element from shared memory sV[][] into global memory, for every component
// devptr must already point to the beginning of the element
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void sum_1d(T *sBuffer[NUM_COMP], T *devptr, const int compstride, const int tx) {
  if (tx >= LENGTH) return;  // extra threads have nothing to accumulate
  for (int icomp = 0; icomp < NUM_COMP; icomp++) {
    devptr[icomp * compstride + tx] += sBuffer[icomp][tx];
  }
}
50 
////////////////////////////////////////////////////////////////////////////////
// read U of a 2D element into registers rU[][][] --  for all components of a single dim
// dU is assumed to be offset by elem-stride and dim-stride
// register is assumed to be rU[DIM_U][NUM_COMP][rU_SIZE]
// i_DIM specifies which dimension is being read into in rU
// rU_SIZE can be different from P (e.g. max(P, Q))
// sTmp is a shared memory workspace of size P^2 (used only as a staging buffer
// for the transpose; its contents are not meaningful after the call)
template <typename T, int P, int DIM_U, int NUM_COMP, int rU_SIZE, int i_DIM>
static __device__ __inline__ void read_U_2d(const T *dU, const int compstride, T rU[DIM_U][NUM_COMP][rU_SIZE], T *sTmp, const int tx) {
  // read U as a batch P of (1 x P) vectors
  // vec 0  : [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // vec 1  : [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // ...
  // vec P-1: [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // threads collaboratively read vec0 and then vec1 and so on
  // but for the kernel, we want
  // thread 0 to hold all of vec0 in registers, and
  // thread 1 to hold all of vec1 in registers, and and so on
  // so we need to transpose
  for (int comp = 0; comp < NUM_COMP; comp++) {
    // stage the component's P x P tile into shared memory, coalesced
    // (consecutive threads read consecutive global addresses)
    if (tx < P) {
      for (int i = 0; i < P; i++) {
        sTmp[i * P + tx] = dU[comp * compstride + i * P + tx];
      }
    }
    // NOTE: barriers are outside the tx < P guard so that every thread in the
    // block reaches them, even threads that do not participate in the copy
    __syncthreads();

    if (tx < P) {
      // transposed read: thread tx gathers the whole of vec tx into registers
      for (int i = 0; i < P; i++) {
        rU[i_DIM][comp][i] = sTmp[tx * P + i];
      }
    }
    // sTmp is reused for the next component; wait until all reads are done
    __syncthreads();
  }
}
87 
////////////////////////////////////////////////////////////////////////////////
// load V of a 2D element into registers rV[][][], for every component of one dim
// dV must already be offset by elem-stride and dim-stride
// the register array has shape rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM selects the target dimension slot inside rV
// rV_SIZE may differ from Q (e.g. max(P, Q))
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void read_V_2d(const T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx >= Q) return;  // only the first Q threads participate
  for (int icomp = 0; icomp < NUM_COMP; icomp++) {
    for (int jq = 0; jq < Q; jq++) {
      rV[i_DIM][icomp][jq] = dV[icomp * compstride + jq * Q + tx];
    }
  }
}
104 
////////////////////////////////////////////////////////////////////////////////
// store V of a 2D element from registers rV[][][] to global memory, for every component of one dim
// dV must already be offset by elem-stride and dim-stride
// the register array has shape rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM selects which dimension slot of rV is written out
// rV_SIZE may differ from Q (e.g. max(P, Q))
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void write_V_2d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx >= Q) return;  // only the first Q threads participate
  for (int icomp = 0; icomp < NUM_COMP; icomp++) {
    for (int jq = 0; jq < Q; jq++) {
      dV[icomp * compstride + jq * Q + tx] = rV[i_DIM][icomp][jq];
    }
  }
}
121 
////////////////////////////////////////////////////////////////////////////////
// accumulate (+=) V of a 2D element from registers rV[][][] into global memory, for every component of one dim
// dV must already be offset by elem-stride and dim-stride
// the register array has shape rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM selects which dimension slot of rV is accumulated
// rV_SIZE may differ from Q (e.g. max(P, Q))
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void sum_V_2d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx >= Q) return;  // only the first Q threads participate
  for (int icomp = 0; icomp < NUM_COMP; icomp++) {
    for (int jq = 0; jq < Q; jq++) {
      dV[icomp * compstride + jq * Q + tx] += rV[i_DIM][icomp][jq];
    }
  }
}
138 
////////////////////////////////////////////////////////////////////////////////
// read U of a 3D element into registers rU[][][] --  for all components of a single dim
// dU is assumed to be offset by elem-stride and dim-stride
// register is assumed to be rU[DIM_U][NUM_COMP][rU_SIZE]
// i_DIM specifies which dimension is being read into in rU
// rU_SIZE can be different from P (e.g. max(P, Q))
// sTmp is a shared memory workspace of size P^3 (used only as a staging buffer
// for the transpose; its contents are not meaningful after the call)
template <typename T, int P, int DIM_U, int NUM_COMP, int rU_SIZE, int i_DIM>
static __device__ __inline__ void read_U_3d(const T *dU, const int compstride, T rU[DIM_U][NUM_COMP][rU_SIZE], T *sTmp, const int tx) {
  // read U as a batch P^2 of (1 x P) vectors
  // vec 0    : [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // vec 1    : [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // ...
  // vec P^2-1: [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // threads collaboratively read vec0 and then vec1 and so on
  // but for the kernel, we want
  // thread 0 to hold all of vec0 in registers, and
  // thread 1 to hold all of vec1 in registers, and and so on
  // so we need to transpose
  for (int comp = 0; comp < NUM_COMP; comp++) {
    // stage the component's P^3 values into shared memory, coalesced
    // (consecutive threads read consecutive global addresses)
    if (tx < P * P) {
      for (int i = 0; i < P; i++) {
        sTmp[i * P * P + tx] = dU[comp * compstride + i * P * P + tx];
      }
    }
    // NOTE: barriers are outside the tx < P*P guard so that every thread in
    // the block reaches them, even threads that do not participate in the copy
    __syncthreads();

    if (tx < P * P) {
      // transposed read: thread tx gathers the whole of vec tx (P contiguous
      // entries starting at tx*P, since sTmp mirrors the global layout)
      for (int i = 0; i < P; i++) {
        rU[i_DIM][comp][i] = sTmp[tx * P + i];
      }
    }
    // sTmp is reused for the next component; wait until all reads are done
    __syncthreads();
  }
}
175 
////////////////////////////////////////////////////////////////////////////////
// load V of a 3D element into registers rV[][][], for every component of one dim
// dV must already be offset by elem-stride and dim-stride
// the register array has shape rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM selects the target dimension slot inside rV
// rV_SIZE may differ from Q (e.g. max(P, Q))
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void read_V_3d(const T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx >= Q * Q) return;  // only the first Q^2 threads participate
  for (int icomp = 0; icomp < NUM_COMP; icomp++) {
    for (int jq = 0; jq < Q; jq++) {
      rV[i_DIM][icomp][jq] = dV[icomp * compstride + jq * (Q * Q) + tx];
    }
  }
}
192 
////////////////////////////////////////////////////////////////////////////////
// store V of a 3D element from registers rV[][][] to global memory, for every component of one dim
// dV must point directly to the element (i.e. already offset by elem-stride)
// the register array has shape rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM selects which dimension slot of rV is written out
// rV_SIZE may differ from Q (e.g. max(P, Q))
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void write_V_3d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx >= Q * Q) return;  // only the first Q^2 threads participate
  for (int icomp = 0; icomp < NUM_COMP; icomp++) {
    for (int jq = 0; jq < Q; jq++) {
      dV[icomp * compstride + jq * (Q * Q) + tx] = rV[i_DIM][icomp][jq];
    }
  }
}
209 
////////////////////////////////////////////////////////////////////////////////
// accumulate (+=) V of a 3D element from registers rV[][][] into global memory, for every component of one dim
// dV must point directly to the element (i.e. already offset by elem-stride)
// the register array has shape rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM selects which dimension slot of rV is accumulated
// rV_SIZE may differ from Q (e.g. max(P, Q))
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void sum_V_3d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx >= Q * Q) return;  // only the first Q^2 threads participate
  for (int icomp = 0; icomp < NUM_COMP; icomp++) {
    for (int jq = 0; jq < Q; jq++) {
      dV[icomp * compstride + jq * (Q * Q) + tx] += rV[i_DIM][icomp][jq];
    }
  }
}
226 
////////////////////////////////////////////////////////////////////////////////
// stage the basis matrix T (no-trans, B x J) from global into shared memory,
// preserving its layout (columns of length B are contiguous)
// a block-level synchronization is required after this call
template <int B, int J>
static __device__ __inline__ void read_T_notrans_gm2sm(const int tx, const CeedScalar *dT, CeedScalar *sT) {
  if (tx >= B) return;  // one thread per row of T
  for (int col = 0; col < J; col++) {
    sT[col * B + tx] = dT[col * B + tx];
  }
  // must sync after call
}
240 
////////////////////////////////////////////////////////////////////////////////
// stage the basis matrix T (trans, J x B) from global into shared memory,
// transposing on the fly so sT ends up stored as B x J
// a block-level synchronization is required after this call
template <int B, int J>
static __device__ __inline__ void read_T_trans_gm2sm(const int tx, const CeedScalar *dT, CeedScalar *sT) {
  if (tx >= J) return;  // one thread per column of the transposed matrix
  for (int row = 0; row < B; row++) {
    sT[tx * B + row] = dT[row * J + tx];
  }
  // must sync after call
}
254