// Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed

/// @file
/// Internal header for MAGMA backend common tensor basis definitions
#pragma once

#include "magma-common-defs.h"

////////////////////////////////////////////////////////////////////////////////
// load U or V of a 1D element from global memory into shared memory
// sU[][] or sV[][], one shared row per component
// devptr is assumed to point directly to the element
// must sync after call
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void read_1d(const T *devptr, const int compstride, T *sBuffer[NUM_COMP], const int tx) {
  // threads past the element length have nothing to load
  if (tx >= LENGTH) return;
  for (int c = 0; c < NUM_COMP; c++) {
    sBuffer[c][tx] = devptr[c * compstride + tx];
  }
}

////////////////////////////////////////////////////////////////////////////////
// store V of a 1D element from shared memory sV[][] back to global memory,
// one shared row per component
// devptr is assumed to point directly to the element
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void write_1d(T *sBuffer[NUM_COMP], T *devptr, const int compstride, const int tx) {
  // threads past the element length have nothing to store
  if (tx >= LENGTH) return;
  for (int c = 0; c < NUM_COMP; c++) {
    devptr[c * compstride + tx] = sBuffer[c][tx];
  }
}

////////////////////////////////////////////////////////////////////////////////
// accumulate (+=) V of a 1D element from shared memory sV[][] into global
// memory, one shared row per component
// devptr is assumed to point directly to the element
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void sum_1d(T *sBuffer[NUM_COMP], T *devptr, const int compstride, const int tx) {
  // threads past the element length have nothing to accumulate
  if (tx >= LENGTH) return;
  for (int c = 0; c < NUM_COMP; c++) {
    devptr[c * compstride + tx] += sBuffer[c][tx];
  }
}

////////////////////////////////////////////////////////////////////////////////
// load U of a 2D element from global memory into registers rU[][][] -- all
// components of a single dim
// dU is assumed to be offset by elem-stride and dim-stride
// register array is rU[DIM_U][NUM_COMP][rU_SIZE]; i_DIM selects the slot of
// rU that receives the data; rU_SIZE can be different from P (e.g. max(P, Q))
// sTmp is a shared memory workspace of size P^2
template <typename T, int P, int DIM_U, int NUM_COMP, int rU_SIZE, int i_DIM>
static __device__ __inline__ void read_U_2d(const T *dU, const int compstride, T rU[DIM_U][NUM_COMP][rU_SIZE], T *sTmp, const int tx) {
  // Global memory holds a batch of P contiguous (1 x P) vectors. Threads read
  // them cooperatively (each pass over all threads covers one vector), but the
  // kernel wants thread tx to own all of vector tx in registers, so the data
  // is staged through shared memory and transposed on the way out.
  for (int c = 0; c < NUM_COMP; c++) {
    // stage: cooperative copy of this component into shared memory
    if (tx < P) {
      for (int j = 0; j < P; j++) {
        sTmp[j * P + tx] = dU[c * compstride + j * P + tx];
      }
    }
    // barrier is outside the branch so every thread in the block reaches it
    __syncthreads();

    // unstage: thread tx gathers vector tx into its registers
    if (tx < P) {
      for (int j = 0; j < P; j++) {
        rU[i_DIM][c][j] = sTmp[tx * P + j];
      }
    }
    // protect sTmp before the next component reuses it
    __syncthreads();
  }
}

////////////////////////////////////////////////////////////////////////////////
// load V of a 2D element from global memory into registers rV[][][] -- all
// components of a single dim
// dV is assumed to be offset by elem-stride and dim-stride
// register array is rV[DIM_V][NUM_COMP][rV_SIZE]; i_DIM selects the slot of
// rV that receives the data; rV_SIZE can be different from P (e.g. max(P, Q))
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void read_V_2d(const T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx >= Q) return;
  for (int c = 0; c < NUM_COMP; c++) {
    // thread tx reads one entry from each of the Q rows of this component
    for (int i = 0; i < Q; i++) {
      rV[i_DIM][c][i] = dV[c * compstride + i * Q + tx];
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// store V of a 2D element from registers rV[][][] to global memory -- all
// components of a single dim
// dV is assumed to be offset by elem-stride and dim-stride
// register array is rV[DIM_V][NUM_COMP][rV_SIZE]; i_DIM selects the slot of
// rV being written out; rV_SIZE can be different from P (e.g. max(P, Q))
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void write_V_2d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx >= Q) return;
  for (int c = 0; c < NUM_COMP; c++) {
    // thread tx writes one entry into each of the Q rows of this component
    for (int i = 0; i < Q; i++) {
      dV[c * compstride + i * Q + tx] = rV[i_DIM][c][i];
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// accumulate (+=) V of a 2D element from registers rV[][][] into global
// memory -- all components of a single dim
// dV is assumed to be offset by elem-stride and dim-stride
// register array is rV[DIM_V][NUM_COMP][rV_SIZE]; i_DIM selects the slot of
// rV being accumulated; rV_SIZE can be different from P (e.g. max(P, Q))
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void sum_V_2d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx >= Q) return;
  for (int c = 0; c < NUM_COMP; c++) {
    // thread tx accumulates one entry into each of the Q rows of this component
    for (int i = 0; i < Q; i++) {
      dV[c * compstride + i * Q + tx] += rV[i_DIM][c][i];
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// load U of a 3D element from global memory into registers rU[][][] -- all
// components of a single dim
// dU is assumed to be offset by elem-stride and dim-stride
// register array is rU[DIM_U][NUM_COMP][rU_SIZE]; i_DIM selects the slot of
// rU that receives the data; rU_SIZE can be different from P (e.g. max(P, Q))
// sTmp is a shared memory workspace of size P^3
template <typename T, int P, int DIM_U, int NUM_COMP, int rU_SIZE, int i_DIM>
static __device__ __inline__ void read_U_3d(const T *dU, const int compstride, T rU[DIM_U][NUM_COMP][rU_SIZE], T *sTmp, const int tx) {
  // Global memory holds a batch of P^2 contiguous (1 x P) vectors. Threads
  // read them cooperatively (each pass over all threads covers one P*P plane),
  // but the kernel wants thread tx to own all of vector tx in registers, so
  // the data is staged through shared memory and transposed on the way out.
  for (int c = 0; c < NUM_COMP; c++) {
    // stage: cooperative copy of this component into shared memory
    if (tx < P * P) {
      for (int j = 0; j < P; j++) {
        sTmp[j * (P * P) + tx] = dU[c * compstride + j * (P * P) + tx];
      }
    }
    // barrier is outside the branch so every thread in the block reaches it
    __syncthreads();

    // unstage: thread tx gathers the P contiguous entries of vector tx
    if (tx < P * P) {
      for (int j = 0; j < P; j++) {
        rU[i_DIM][c][j] = sTmp[tx * P + j];
      }
    }
    // protect sTmp before the next component reuses it
    __syncthreads();
  }
}

////////////////////////////////////////////////////////////////////////////////
// load V of a 3D element from global memory into registers rV[][][] -- all
// components of a single dim
// dV is assumed to be offset by elem-stride and dim-stride
// register array is rV[DIM_V][NUM_COMP][rV_SIZE]; i_DIM selects the slot of
// rV that receives the data; rV_SIZE can be different from P (e.g. max(P, Q))
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void read_V_3d(const T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx >= Q * Q) return;
  for (int c = 0; c < NUM_COMP; c++) {
    // thread tx reads one entry from each of the Q planes of this component
    for (int i = 0; i < Q; i++) {
      rV[i_DIM][c][i] = dV[c * compstride + i * (Q * Q) + tx];
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// store V of a 3D element from registers rV[][][] to global memory -- all
// components of a single dim
// dV is assumed to point directly to the element (i.e. already offset by elem-stride)
// register array is rV[DIM_V][NUM_COMP][rV_SIZE]; i_DIM selects the slot of
// rV being written out; rV_SIZE can be different from P (e.g. max(P, Q))
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void write_V_3d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx >= Q * Q) return;
  for (int c = 0; c < NUM_COMP; c++) {
    // thread tx writes one entry into each of the Q planes of this component
    for (int i = 0; i < Q; i++) {
      dV[c * compstride + i * (Q * Q) + tx] = rV[i_DIM][c][i];
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// accumulate (+=) V of a 3D element from registers rV[][][] into global
// memory -- all components of a single dim
// dV is assumed to point directly to the element (i.e. already offset by elem-stride)
// register array is rV[DIM_V][NUM_COMP][rV_SIZE]; i_DIM selects the slot of
// rV being accumulated; rV_SIZE can be different from P (e.g. max(P, Q))
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void sum_V_3d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx >= Q * Q) return;
  for (int c = 0; c < NUM_COMP; c++) {
    // thread tx accumulates one entry into each of the Q planes of this component
    for (int i = 0; i < Q; i++) {
      dV[c * compstride + i * (Q * Q) + tx] += rV[i_DIM][c][i];
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// load the basis matrix T (no-trans) from global to shared memory
// T is B x J; the layout in sT is identical to the layout in dT
// must sync after call
template <int B, int J>
static __device__ __inline__ void read_T_notrans_gm2sm(const int tx, const CeedScalar *dT, CeedScalar *sT) {
  if (tx >= B) return;
  // straight copy: thread tx handles one entry of each of the J columns
  for (int j = 0; j < J; j++) {
    sT[j * B + tx] = dT[j * B + tx];
  }
  // must sync after call
}

////////////////////////////////////////////////////////////////////////////////
// load the basis matrix T (trans) from global to shared memory
// T is J x B in dT; the copy is transposed so sT holds it in B-stride layout
// must sync after call
template <int B, int J>
static __device__ __inline__ void read_T_trans_gm2sm(const int tx, const CeedScalar *dT, CeedScalar *sT) {
  if (tx >= J) return;
  // transposed copy: thread tx fills one contiguous length-B run of sT
  for (int j = 0; j < B; j++) {
    sT[tx * B + j] = dT[j * J + tx];
  }
  // must sync after call
}