xref: /libCEED/include/ceed/jit-source/magma/magma-common-tensor.h (revision 715f9ba89a309f24226005ca1fbb9f59fe9eac68)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 /// @file
9 /// Internal header for MAGMA backend common tensor basis definitions
10 #ifndef CEED_MAGMA_COMMON_TENSOR_H
11 #define CEED_MAGMA_COMMON_TENSOR_H
12 
13 #include "magma-common-defs.h"
14 
//////////////////////////////////////////////////////////////////////////////////////////
// Load U or V of a 1D element from global memory into shared memory sU[][]/sV[][],
// one slot per component. devptr must already point at the element.
// Threads with tx < LENGTH each copy one entry per component; the caller must
// __syncthreads() after this returns before reading the shared buffers.
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void read_1d(const T *devptr, const int compstride, T *sBuffer[NUM_COMP], const int tx) {
  if (tx >= LENGTH) return;  // surplus threads have nothing to load (no barrier below, so early exit is safe)
  for (int icomp = 0; icomp < NUM_COMP; icomp++) {
    sBuffer[icomp][tx] = devptr[icomp * compstride + tx];
  }
}
27 
//////////////////////////////////////////////////////////////////////////////////////////
// Store V of a 1D element from shared memory sV[][] back to global memory,
// one slot per component. devptr must already point at the element.
// Threads with tx < LENGTH each write one entry per component.
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void write_1d(T *sBuffer[NUM_COMP], T *devptr, const int compstride, const int tx) {
  if (tx >= LENGTH) return;  // surplus threads have nothing to store (no barrier below, so early exit is safe)
  for (int icomp = 0; icomp < NUM_COMP; icomp++) {
    devptr[icomp * compstride + tx] = sBuffer[icomp][tx];
  }
}
39 
//////////////////////////////////////////////////////////////////////////////////////////
// Load U of a 2D element from global memory into registers rU[][][] for all components
// of a single dimension. dU must already be offset by elem-stride and dim-stride.
// rU has shape rU[DIM_U][NUM_COMP][rU_SIZE]; i_DIM selects the destination dimension
// slot, and rU_SIZE may differ from P (e.g. MAXP_Q).
// sTmp is a shared-memory scratch area of at least P^2 entries.
//
// Per component, global memory holds P contiguous vectors of length P. The first pass
// copies them into sTmp with coalesced accesses; after a barrier, the second pass hands
// all of vector tx to thread tx (a transpose through shared memory), which is the
// per-thread register layout the tensor-contraction kernels expect.
template <typename T, int P, int DIM_U, int NUM_COMP, int rU_SIZE, int i_DIM>
static __device__ __inline__ void readU_2d(const T *dU, const int compstride, T rU[DIM_U][NUM_COMP][rU_SIZE], T *sTmp, const int tx) {
  for (int icomp = 0; icomp < NUM_COMP; icomp++) {
    // stage: coalesced copy of this component's P x P block into shared memory
    if (tx < P) {
      for (int j = 0; j < P; j++) {
        sTmp[j * P + tx] = dU[icomp * compstride + j * P + tx];
      }
    }
    __syncthreads();  // staged data must be fully visible before the transposed read

    // gather: thread tx collects the P entries of vector tx into its registers
    if (tx < P) {
      for (int j = 0; j < P; j++) {
        rU[i_DIM][icomp][j] = sTmp[tx * P + j];
      }
    }
    __syncthreads();  // protect sTmp from being overwritten by the next component
  }
}
76 
//////////////////////////////////////////////////////////////////////////////////////////
// Load V of a 2D element from global memory into registers rV[][][] for all components
// of a single dimension. dV must already be offset by elem-stride and dim-stride.
// rV has shape rV[DIM_V][NUM_COMP][rV_SIZE]; i_DIM selects the destination dimension
// slot, and rV_SIZE may differ from Q (e.g. MAXP_Q).
// Thread tx (tx < Q) reads entries strided by Q, so accesses are coalesced across threads.
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void readV_2d(const T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx >= Q) return;  // no barrier in this routine, so early exit is safe
  for (int icomp = 0; icomp < NUM_COMP; icomp++) {
    for (int jq = 0; jq < Q; jq++) {
      rV[i_DIM][icomp][jq] = dV[icomp * compstride + jq * Q + tx];
    }
  }
}
93 
//////////////////////////////////////////////////////////////////////////////////////////
// Store V of a 2D element from registers rV[][][] to global memory for all components
// of a single dimension. dV must already be offset by elem-stride and dim-stride.
// rV has shape rV[DIM_V][NUM_COMP][rV_SIZE]; i_DIM selects which dimension slot of rV
// is written out, and rV_SIZE may differ from Q (e.g. MAXP_Q).
// Thread tx (tx < Q) writes entries strided by Q, so accesses are coalesced across threads.
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void writeV_2d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx >= Q) return;  // no barrier in this routine, so early exit is safe
  for (int icomp = 0; icomp < NUM_COMP; icomp++) {
    for (int jq = 0; jq < Q; jq++) {
      dV[icomp * compstride + jq * Q + tx] = rV[i_DIM][icomp][jq];
    }
  }
}
111 
//////////////////////////////////////////////////////////////////////////////////////////
// Load U of a 3D element from global memory into registers rU[][][] for all components
// of a single dimension. dU must already be offset by elem-stride and dim-stride.
// rU has shape rU[DIM_U][NUM_COMP][rU_SIZE]; i_DIM selects the destination dimension
// slot, and rU_SIZE may differ from P (e.g. MAXP_Q).
// sTmp is a shared-memory scratch area of at least P^3 entries.
//
// Per component, global memory holds P^2 contiguous vectors of length P. The first pass
// copies the whole P^3 block into sTmp with coalesced accesses; after a barrier, the
// second pass hands all of vector tx to thread tx (a transpose through shared memory),
// which is the per-thread register layout the tensor-contraction kernels expect.
template <typename T, int P, int DIM_U, int NUM_COMP, int rU_SIZE, int i_DIM>
static __device__ __inline__ void readU_3d(const T *dU, const int compstride, T rU[DIM_U][NUM_COMP][rU_SIZE], T *sTmp, const int tx) {
  for (int icomp = 0; icomp < NUM_COMP; icomp++) {
    // stage: coalesced copy of this component's P^3 block into shared memory
    if (tx < P * P) {
      for (int j = 0; j < P; j++) {
        sTmp[j * P * P + tx] = dU[icomp * compstride + j * P * P + tx];
      }
    }
    __syncthreads();  // staged data must be fully visible before the transposed read

    // gather: thread tx (tx < P^2) collects the P entries of vector tx into its registers
    if (tx < P * P) {
      for (int j = 0; j < P; j++) {
        rU[i_DIM][icomp][j] = sTmp[tx * P + j];
      }
    }
    __syncthreads();  // protect sTmp from being overwritten by the next component
  }
}
148 
//////////////////////////////////////////////////////////////////////////////////////////
// Load V of a 3D element from global memory into registers rV[][][] for all components
// of a single dimension. dV must already be offset by elem-stride and dim-stride.
// rV has shape rV[DIM_V][NUM_COMP][rV_SIZE]; i_DIM selects the destination dimension
// slot, and rV_SIZE may differ from Q (e.g. MAXP_Q).
// Thread tx (tx < Q^2) reads entries strided by Q^2, so accesses are coalesced across threads.
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void readV_3d(const T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx >= Q * Q) return;  // no barrier in this routine, so early exit is safe
  for (int icomp = 0; icomp < NUM_COMP; icomp++) {
    for (int jq = 0; jq < Q; jq++) {
      rV[i_DIM][icomp][jq] = dV[icomp * compstride + jq * (Q * Q) + tx];
    }
  }
}
165 
//////////////////////////////////////////////////////////////////////////////////////////
// Store V of a 3D element from registers rV[][][] to global memory for all components
// of a single dimension. dV must already point directly at the element (i.e. offset by
// elem-stride). rV has shape rV[DIM_V][NUM_COMP][rV_SIZE]; i_DIM selects which dimension
// slot of rV is written out, and rV_SIZE may differ from Q (e.g. MAXP_Q).
// Thread tx (tx < Q^2) writes entries strided by Q^2, so accesses are coalesced across threads.
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void writeV_3d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx >= Q * Q) return;  // no barrier in this routine, so early exit is safe
  for (int icomp = 0; icomp < NUM_COMP; icomp++) {
    for (int jq = 0; jq < Q; jq++) {
      dV[icomp * compstride + jq * (Q * Q) + tx] = rV[i_DIM][icomp][jq];
    }
  }
}
183 
//////////////////////////////////////////////////////////////////////////////////////////
// Load the basis matrix T from global memory into shared memory, storing it B x J
// regardless of transT. When transT != MagmaNoTrans, dT is laid out J x B and is
// transposed on the fly during the load.
// The caller must __syncthreads() after this returns before reading sT.
template <int B, int J>
static __device__ __inline__ void dread_T_gm2sm(const int tx, const magma_trans_t transT, const CeedScalar *dT, CeedScalar *sT) {
  if (transT == MagmaNoTrans) {
    // dT is B x J: straight copy, column j at a time
    if (tx < B) {
      for (int j = 0; j < J; j++) {
        sT[j * B + tx] = dT[j * B + tx];
      }
    }
  } else {
    // dT is J x B: transpose while loading so sT ends up B x J
    if (tx < J) {
      for (int j = 0; j < B; j++) {
        sT[tx * B + j] = dT[j * J + tx];
      }
    }
  }
  // no barrier here by design; the caller synchronizes
}
206 
207 #endif  // CEED_MAGMA_COMMON_TENSOR_H
208