xref: /libCEED/include/ceed/jit-source/magma/magma-common-tensor.h (revision 9dc0ea9a12d5a2dbb50983bee29c25b398979cc0)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 /// @file
9 /// Internal header for MAGMA backend common tensor basis definitions
10 #pragma once
11 
12 #include "magma-common-defs.h"
13 
////////////////////////////////////////////////////////////////////////////////
// load U or V of a 1D element from global memory into shared memory sU[][] or
// sV[][], for every component at once
// devptr is assumed to point directly at the element's data
// threads tx >= LENGTH do nothing; caller must sync after this returns
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void read_1d(const T *devptr, const int compstride, T *sBuffer[NUM_COMP], const int tx) {
  if (tx >= LENGTH) return;
  for (int c = 0; c < NUM_COMP; c++) {
    sBuffer[c][tx] = devptr[c * compstride + tx];
  }
}
26 
////////////////////////////////////////////////////////////////////////////////
// store V of a 1D element from shared memory sV[][] back to global memory,
// for every component at once
// devptr is assumed to point directly at the element's data
// threads tx >= LENGTH do nothing
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void write_1d(T *sBuffer[NUM_COMP], T *devptr, const int compstride, const int tx) {
  if (tx >= LENGTH) return;
  for (int c = 0; c < NUM_COMP; c++) {
    devptr[c * compstride + tx] = sBuffer[c][tx];
  }
}
38 
////////////////////////////////////////////////////////////////////////////////
// read U of a 2D element into registers rU[][][] --  for all components of a single dim
// dU is assumed to be offset by elem-stride and dim-stride
// register is assumed to be rU[DIM_U][NUM_COMP][rU_SIZE]
// i_DIM specifies which dimension is being read into in rU
// rU_SIZE can be different from P (e.g. max(P, Q))
// sTmp is a shared memory workspace of size P^2
// contains internal __syncthreads(): ALL threads of the block must call this
template <typename T, int P, int DIM_U, int NUM_COMP, int rU_SIZE, int i_DIM>
static __device__ __inline__ void read_U_2d(const T *dU, const int compstride, T rU[DIM_U][NUM_COMP][rU_SIZE], T *sTmp, const int tx) {
  // read U as a batch P of (1 x P) vectors
  // vec 0  : [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // vec 1  : [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // ...
  // vec P-1: [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // threads collaboratively read vec0 and then vec1 and so on
  // but for the kernel, we want
  // thread 0 to hold all of vec0 in registers, and
  // thread 1 to hold all of vec1 in registers, and so on
  // so we need to transpose
  for (int comp = 0; comp < NUM_COMP; comp++) {
    // read from global memory into shared memory
    // (tx varies fastest within each vector, so the global reads coalesce)
    if (tx < P) {
      for (int i = 0; i < P; i++) {
        sTmp[i * P + tx] = dU[comp * compstride + i * P + tx];
      }
    }
    // barrier is deliberately outside the divergent if: every thread reaches it
    __syncthreads();

    // transposed read-back: thread tx gathers the P contiguous entries of row tx
    if (tx < P) {
      for (int i = 0; i < P; i++) {
        rU[i_DIM][comp][i] = sTmp[tx * P + i];
      }
    }
    // guard sTmp before it is overwritten by the next component's load
    __syncthreads();
  }
}
75 
////////////////////////////////////////////////////////////////////////////////
// read V of a 2D element into registers rV[][][] --  for all components of a single dim
// dV is assumed to be offset by elem-stride and dim-stride
// register is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM specifies which dimension is being read into in rV
// rV_SIZE can be different from Q (e.g. max(P, Q))
// threads tx >= Q do nothing; no synchronization is performed
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void read_V_2d(const T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx >= Q) return;
  for (int c = 0; c < NUM_COMP; c++) {
    // walk this component's data with stride Q; tx selects the lane
    const T *src = dV + c * compstride + tx;
    for (int j = 0; j < Q; j++) {
      rV[i_DIM][c][j] = src[j * Q];
    }
  }
}
92 
////////////////////////////////////////////////////////////////////////////////
// write V of a 2D element from registers rV[][][] to global memory --  for all components of a single dim
// dV is assumed to be offset by elem-stride and dim-stride
// register is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM specifies which dimension is being written from in rV
// rV_SIZE can be different from Q (e.g. max(P, Q))
// threads tx >= Q do nothing; no synchronization is performed
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void write_V_2d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx >= Q) return;
  for (int c = 0; c < NUM_COMP; c++) {
    // scatter this component's registers back with stride Q; tx selects the lane
    T *dst = dV + c * compstride + tx;
    for (int j = 0; j < Q; j++) {
      dst[j * Q] = rV[i_DIM][c][j];
    }
  }
}
109 
////////////////////////////////////////////////////////////////////////////////
// read U of a 3D element into registers rU[][][] --  for all components of a single dim
// dU is assumed to be offset by elem-stride and dim-stride
// register is assumed to be rU[DIM_U][NUM_COMP][rU_SIZE]
// i_DIM specifies which dimension is being read into in rU
// rU_SIZE can be different from P (e.g. max(P, Q))
// sTmp is a shared memory workspace of size P^3
// contains internal __syncthreads(): ALL threads of the block must call this
template <typename T, int P, int DIM_U, int NUM_COMP, int rU_SIZE, int i_DIM>
static __device__ __inline__ void read_U_3d(const T *dU, const int compstride, T rU[DIM_U][NUM_COMP][rU_SIZE], T *sTmp, const int tx) {
  // read U as a batch P^2 of (1 x P) vectors
  // vec 0    : [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // vec 1    : [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // ...
  // vec P^2-1: [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // threads collaboratively read vec0 and then vec1 and so on
  // but for the kernel, we want
  // thread 0 to hold all of vec0 in registers, and
  // thread 1 to hold all of vec1 in registers, and so on
  // so we need to transpose
  for (int comp = 0; comp < NUM_COMP; comp++) {
    // read from global memory into shared memory
    // (each of the P*P threads reads P entries at stride P*P, so reads coalesce)
    if (tx < P * P) {
      for (int i = 0; i < P; i++) {
        sTmp[i * P * P + tx] = dU[comp * compstride + i * P * P + tx];
      }
    }
    // barrier is deliberately outside the divergent if: every thread reaches it
    __syncthreads();

    // transposed read-back: thread tx gathers its P contiguous entries
    if (tx < P * P) {
      for (int i = 0; i < P; i++) {
        rU[i_DIM][comp][i] = sTmp[tx * P + i];
      }
    }
    // guard sTmp before it is overwritten by the next component's load
    __syncthreads();
  }
}
146 
////////////////////////////////////////////////////////////////////////////////
// read V of a 3D element into registers rV[][][] --  for all components of a single dim
// dV is assumed to be offset by elem-stride and dim-stride
// register is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM specifies which dimension is being read into in rV
// rV_SIZE can be different from Q (e.g. max(P, Q))
// threads tx >= Q*Q do nothing; no synchronization is performed
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void read_V_3d(const T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx >= Q * Q) return;
  for (int c = 0; c < NUM_COMP; c++) {
    // walk this component's data with stride Q*Q; tx selects the lane
    const T *src = dV + c * compstride + tx;
    for (int j = 0; j < Q; j++) {
      rV[i_DIM][c][j] = src[j * (Q * Q)];
    }
  }
}
163 
////////////////////////////////////////////////////////////////////////////////
// write V of a 3D element from registers rV[][][] to global memory --  for all components of a single dim
// dV is assumed to be already offset by the caller (elem-stride; presumably
// dim-stride as in the 2D variant -- confirm against call sites)
// register is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM specifies which dimension is being written from in rV
// rV_SIZE can be different from Q (e.g. max(P, Q))
// threads tx >= Q*Q do nothing; no synchronization is performed
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void write_V_3d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx >= Q * Q) return;
  for (int c = 0; c < NUM_COMP; c++) {
    // scatter this component's registers back with stride Q*Q; tx selects the lane
    T *dst = dV + c * compstride + tx;
    for (int j = 0; j < Q; j++) {
      dst[j * (Q * Q)] = rV[i_DIM][c][j];
    }
  }
}
180 
////////////////////////////////////////////////////////////////////////////////
// load the B x J basis matrix T (no-trans) from global into shared memory,
// preserving its column-major layout (sT[j*B + tx] = dT[j*B + tx])
// threads tx >= B do nothing
// must sync after call
template <int B, int J>
static __device__ __inline__ void read_T_notrans_gm2sm(const int tx, const CeedScalar *dT, CeedScalar *sT) {
  if (tx >= B) return;
  // thread tx copies row tx of every column; caller syncs afterwards
  for (int j = 0; j < J; j++) {
    sT[j * B + tx] = dT[j * B + tx];
  }
}
194 
////////////////////////////////////////////////////////////////////////////////
// load the J x B basis matrix T (trans) from global into shared memory,
// transposing on the fly so sT ends up in the same B x J layout as the
// no-trans case (sT[tx*B + k] = dT[k*J + tx])
// threads tx >= J do nothing
// must sync after call
template <int B, int J>
static __device__ __inline__ void read_T_trans_gm2sm(const int tx, const CeedScalar *dT, CeedScalar *sT) {
  if (tx >= J) return;
  // thread tx produces one contiguous length-B row of sT from a strided column of dT
  CeedScalar *row = sT + tx * B;
  for (int k = 0; k < B; k++) {
    row[k] = dT[k * J + tx];
  }
  // caller syncs afterwards
}
208