// Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed

/// @file
/// Internal header for MAGMA backend common tensor basis definitions
#pragma once

#include "magma-common-defs.h"

////////////////////////////////////////////////////////////////////////////////
// read U or V of a 1D element into shared memory sU[][] or sV[][] -- for all components
// the devptr is assumed to point directly to the element
// compstride is the stride (in entries) between consecutive components in global memory
// tx is the thread id within the slab assigned to this element; threads with tx >= LENGTH do nothing
// must sync after call
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void read_1d(const T *devptr, const int compstride, T *sBuffer[NUM_COMP], const int tx) {
  if (tx < LENGTH) {
    for (int comp = 0; comp < NUM_COMP; comp++) {
      sBuffer[comp][tx] = devptr[comp * compstride + tx];
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// write V of a 1D element into global memory from sV[][] -- for all components
// the devptr is assumed to point directly to the element
// caller must ensure sBuffer is fully populated (sync) before calling
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void write_1d(T *sBuffer[NUM_COMP], T *devptr, const int compstride, const int tx) {
  if (tx < LENGTH) {
    for (int comp = 0; comp < NUM_COMP; comp++) {
      devptr[comp * compstride + tx] = sBuffer[comp][tx];
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// sum into V of a 1D element in global memory from sV[][] -- for all components
// the devptr is assumed to point directly to the element
// same as write_1d, but accumulates (+=) instead of overwriting
template <typename T, int LENGTH, int NUM_COMP>
static __device__ __inline__ void sum_1d(T *sBuffer[NUM_COMP], T *devptr, const int compstride, const int tx) {
  if (tx < LENGTH) {
    for (int comp = 0; comp < NUM_COMP; comp++) {
      devptr[comp * compstride + tx] += sBuffer[comp][tx];
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// read U of a 2D element into registers rU[][][] -- for all components of a single dim
// dU is assumed to be offset by elem-stride and dim-stride
// register is assumed to be rU[DIM_U][NUM_COMP][rU_SIZE]
// i_DIM specifies which dimension is being read into in rU
// rU_SIZE can be different from P (e.g. max(P, Q))
// sTmp is a shared memory workspace of size P^2
template <typename T, int P, int DIM_U, int NUM_COMP, int rU_SIZE, int i_DIM>
static __device__ __inline__ void read_U_2d(const T *dU, const int compstride, T rU[DIM_U][NUM_COMP][rU_SIZE], T *sTmp, const int tx) {
  // read U as a batch P of (1 x P) vectors
  // vec 0  : [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // vec 1  : [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // ...
  // vec P-1: [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // threads collaboratively read vec0 and then vec1 and so on
  // but for the kernel, we want
  // thread 0 to hold all of vec0 in registers, and
  // thread 1 to hold all of vec1 in registers, and so on
  // so we need to transpose
  for (int comp = 0; comp < NUM_COMP; comp++) {
    // stage the component in shared memory: adjacent threads touch adjacent
    // global addresses (index i * P + tx), i.e. a coalesced copy of all P^2 entries
    if (tx < P) {
      for (int i = 0; i < P; i++) {
        sTmp[i * P + tx] = dU[comp * compstride + i * P + tx];
      }
    }
    __syncthreads();  // sTmp fully written before any thread gathers from it

    // thread tx gathers its own vector (row tx of sTmp) into registers
    if (tx < P) {
      for (int i = 0; i < P; i++) {
        rU[i_DIM][comp][i] = sTmp[tx * P + i];
      }
    }
    __syncthreads();  // all reads done before sTmp is reused for the next component
  }
}

////////////////////////////////////////////////////////////////////////////////
// read V of a 2D element into registers rV[][][] -- for all components of a single dim
// dV is assumed to be offset by elem-stride and dim-stride
// register is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM specifies which dimension is being read into in rV
// rV_SIZE can be different from P (e.g. max(P, Q))
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void read_V_2d(const T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx < Q) {
    for (int comp = 0; comp < NUM_COMP; comp++) {
      // thread tx reads column tx of the Q x Q data (stride Q between entries)
      for (int j = 0; j < Q; j++) {
        rV[i_DIM][comp][j] = dV[comp * compstride + j * Q + tx];
      }
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// write V of a 2D element from registers rV[][][] to global memory -- for all components of a single dim
// dV is assumed to be offset by elem-stride and dim-stride
// register is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM specifies which dimension is being written to in dV
// rV_SIZE can be different from P (e.g. max(P, Q))
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void write_V_2d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx < Q) {
    for (int comp = 0; comp < NUM_COMP; comp++) {
      // thread tx writes column tx of the Q x Q data (mirror of read_V_2d)
      for (int j = 0; j < Q; j++) {
        dV[comp * compstride + j * Q + tx] = rV[i_DIM][comp][j];
      }
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// sum into V of a 2D element from registers rV[][][] to global memory -- for all components of a single dim
// dV is assumed to be offset by elem-stride and dim-stride
// register is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM specifies which dimension is being written to in dV
// rV_SIZE can be different from P (e.g. max(P, Q))
// same as write_V_2d, but accumulates (+=) instead of overwriting
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void sum_V_2d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx < Q) {
    for (int comp = 0; comp < NUM_COMP; comp++) {
      for (int j = 0; j < Q; j++) {
        dV[comp * compstride + j * Q + tx] += rV[i_DIM][comp][j];
      }
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// read U of a 3D element into registers rU[][][] -- for all components of a single dim
// dU is assumed to be offset by elem-stride and dim-stride
// register is assumed to be rU[DIM_U][NUM_COMP][rU_SIZE]
// i_DIM specifies which dimension is being read into in rU
// rU_SIZE can be different from P (e.g. max(P, Q))
// sTmp is a shared memory workspace of size P^3
template <typename T, int P, int DIM_U, int NUM_COMP, int rU_SIZE, int i_DIM>
static __device__ __inline__ void read_U_3d(const T *dU, const int compstride, T rU[DIM_U][NUM_COMP][rU_SIZE], T *sTmp, const int tx) {
  // read U as a batch P^2 of (1 x P) vectors
  // vec 0    : [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // vec 1    : [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // ...
  // vec P^2-1: [u0, u1, u2, ... u_(P-1)] -- contiguous in memory
  // threads collaboratively read vec0 and then vec1 and so on
  // but for the kernel, we want
  // thread 0 to hold all of vec0 in registers, and
  // thread 1 to hold all of vec1 in registers, and so on
  // so we need to transpose
  for (int comp = 0; comp < NUM_COMP; comp++) {
    // stage the component in shared memory: adjacent threads touch adjacent
    // global addresses (index i * P^2 + tx), i.e. a coalesced copy of all P^3 entries
    if (tx < P * P) {
      for (int i = 0; i < P; i++) {
        sTmp[i * P * P + tx] = dU[comp * compstride + i * P * P + tx];
      }
    }
    __syncthreads();  // sTmp fully written before any thread gathers from it

    // thread tx gathers its own vector (the P contiguous entries starting at tx * P)
    if (tx < P * P) {
      for (int i = 0; i < P; i++) {
        rU[i_DIM][comp][i] = sTmp[tx * P + i];
      }
    }
    __syncthreads();  // all reads done before sTmp is reused for the next component
  }
}

////////////////////////////////////////////////////////////////////////////////
// read V of a 3D element into registers rV[][][] -- for all components of a single dim
// dV is assumed to be offset by elem-stride and dim-stride
// register is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM specifies which dimension is being read into in rV
// rV_SIZE can be different from P (e.g. max(P, Q))
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void read_V_3d(const T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx < Q * Q) {
    for (int comp = 0; comp < NUM_COMP; comp++) {
      // thread tx reads Q entries with stride Q^2 (one per j-slice)
      for (int j = 0; j < Q; j++) {
        rV[i_DIM][comp][j] = dV[comp * compstride + j * (Q * Q) + tx];
      }
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// write V of a 3D element from registers rV[][][] to global memory -- for all components of a single dim
// dV is assumed to point directly to the element (i.e. already offset by elem-stride)
// register is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM specifies which dimension is being written to in dV
// rV_SIZE can be different from P (e.g. max(P, Q))
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void write_V_3d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx < (Q * Q)) {
    for (int comp = 0; comp < NUM_COMP; comp++) {
      // thread tx writes Q entries with stride Q^2 (mirror of read_V_3d)
      for (int j = 0; j < Q; j++) {
        dV[comp * compstride + j * (Q * Q) + tx] = rV[i_DIM][comp][j];
      }
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// sum into V of a 3D element from registers rV[][][] to global memory -- for all components of a single dim
// dV is assumed to point directly to the element (i.e. already offset by elem-stride)
// register is assumed to be rV[DIM_V][NUM_COMP][rV_SIZE]
// i_DIM specifies which dimension is being written to in dV
// rV_SIZE can be different from P (e.g. max(P, Q))
// same as write_V_3d, but accumulates (+=) instead of overwriting
template <typename T, int Q, int DIM_V, int NUM_COMP, int rV_SIZE, int i_DIM>
static __device__ __inline__ void sum_V_3d(T *dV, const int compstride, T rV[DIM_V][NUM_COMP][rV_SIZE], const int tx) {
  if (tx < (Q * Q)) {
    for (int comp = 0; comp < NUM_COMP; comp++) {
      for (int j = 0; j < Q; j++) {
        dV[comp * compstride + j * (Q * Q) + tx] += rV[i_DIM][comp][j];
      }
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// reads T (no-trans) into shared memory
// T is B x J
// straight copy: sT keeps the same layout as dT
// must sync after call
template <int B, int J>
static __device__ __inline__ void read_T_notrans_gm2sm(const int tx, const CeedScalar *dT, CeedScalar *sT) {
  if (tx < B) {
    for (int i = 0; i < J; i++) {
      sT[i * B + tx] = dT[i * B + tx];
    }
  }
  // must sync after call
}

////////////////////////////////////////////////////////////////////////////////
// reads T (trans) into shared memory
// T is J x B
// copies dT transposed into sT (sT[tx * B + i] = dT[i * J + tx]), so sT ends up
// with the same layout read_T_notrans_gm2sm produces
// must sync after call
template <int B, int J>
static __device__ __inline__ void read_T_trans_gm2sm(const int tx, const CeedScalar *dT, CeedScalar *sT) {
  if (tx < J) {
    for (int i = 0; i < B; i++) {
      sT[tx * B + i] = dT[i * J + tx];
    }
  }
  // must sync after call
}