// Copyright (c) 2017-2025, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed

/// @file
/// Internal header for MAGMA tensor basis interpolation in 1D
#include "magma-common-tensor.h"

// macro abstracting access to the shared-memory copy of the T matrix (column-major, PxQ)
#define sT(i, j) sT[(j) * P + (i)]

////////////////////////////////////////////////////////////////////////////////
// interp basis action (1D)
template <typename T, int DIM, int NUM_COMP, int P, int Q>
static __device__ __inline__ void magma_interp_1d_device(const T *sT, T *sU[NUM_COMP], T *sV[NUM_COMP], const int tx) {
  // Assumptions
  // 1. 1D threads of size max(P,Q)
  // 2. sU[i] is 1xP: in shared memory
  // 3. sV[i] is 1xQ: in shared memory
  // 4. Product per component is one row (1xP) times T matrix (PxQ) => one row (1xQ)
  // 5. Each thread computes one entry in sV[i]
  // 6. Must sync before and after call
  // 7. Note that the layout for U and V is different from the 2D/3D problems

  if (tx < Q) {
    for (int icomp = 0; icomp < NUM_COMP; icomp++) {
      T accum = 0.0;
      for (int k = 0; k < P; k++) {
        accum += sU[icomp][k] * sT(k, tx);  // sT[tx * P + k]
      }
      sV[icomp][tx] = accum;
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// interp, no transpose: V = U * T, per element and component
extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_MAX_P_Q, MAGMA_MAXTHREADS_1D)) __global__
    void magma_interpn_1d_kernel(const CeedScalar *dT, const CeedScalar *dU, const int estrdU, const int cstrdU, CeedScalar *dV, const int estrdV,
                                 const int cstrdV, const int nelem) {
  MAGMA_DEVICE_SHARED(CeedScalar, shared_data)

  const int tx      = threadIdx.x;
  const int ty      = threadIdx.y;
  const int elem_id = (blockIdx.x * blockDim.y) + ty;

  if (elem_id >= nelem) return;

  // shift global memory pointers by elem stride
  dU += elem_id * estrdU;
  dV += elem_id * estrdV;

  // shared-memory layout: [T matrix | per-ty workspace (all U components, then all V components)]
  CeedScalar *sT = (CeedScalar *)shared_data;
  CeedScalar *sW = sT + BASIS_P * BASIS_Q;

  CeedScalar *sU[BASIS_NUM_COMP];
  CeedScalar *sV[BASIS_NUM_COMP];
  CeedScalar *slab_U = sW + ty * BASIS_NUM_COMP * (BASIS_P + BASIS_Q);
  CeedScalar *slab_V = slab_U + BASIS_NUM_COMP * BASIS_P;
  for (int comp = 0; comp < BASIS_NUM_COMP; comp++) {
    sU[comp] = slab_U + comp * BASIS_P;
    sV[comp] = slab_V + comp * BASIS_Q;
  }

  // read T (no transpose); only the ty == 0 slab loads it for the whole block
  if (ty == 0) {
    read_T_notrans_gm2sm<BASIS_P, BASIS_Q>(tx, dT, sT);
  }

  // read U into shared memory
  read_1d<CeedScalar, BASIS_P, BASIS_NUM_COMP>(dU, cstrdU, sU, tx);

  __syncthreads();
  magma_interp_1d_device<CeedScalar, BASIS_DIM, BASIS_NUM_COMP, BASIS_P, BASIS_Q>(sT, sU, sV, tx);
  __syncthreads();

  // write V back to global memory
  write_1d<CeedScalar, BASIS_Q, BASIS_NUM_COMP>(sV, dV, cstrdV, tx);
}

////////////////////////////////////////////////////////////////////////////////
// interp, transpose: V = U * T^t, per element and component (P and Q roles swap)
extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_MAX_P_Q, MAGMA_MAXTHREADS_1D)) __global__
    void magma_interpt_1d_kernel(const CeedScalar *dT, const CeedScalar *dU, const int estrdU, const int cstrdU, CeedScalar *dV, const int estrdV,
                                 const int cstrdV, const int nelem) {
  MAGMA_DEVICE_SHARED(CeedScalar, shared_data)

  const int tx      = threadIdx.x;
  const int ty      = threadIdx.y;
  const int elem_id = (blockIdx.x * blockDim.y) + ty;

  if (elem_id >= nelem) return;

  // shift global memory pointers by elem stride
  dU += elem_id * estrdU;
  dV += elem_id * estrdV;

  // shared-memory layout: [T matrix | per-ty workspace (all U components, then all V components)]
  CeedScalar *sT = (CeedScalar *)shared_data;
  CeedScalar *sW = sT + BASIS_Q * BASIS_P;

  CeedScalar *sU[BASIS_NUM_COMP];
  CeedScalar *sV[BASIS_NUM_COMP];
  CeedScalar *slab_U = sW + ty * BASIS_NUM_COMP * (BASIS_Q + BASIS_P);
  CeedScalar *slab_V = slab_U + BASIS_NUM_COMP * BASIS_Q;
  for (int comp = 0; comp < BASIS_NUM_COMP; comp++) {
    sU[comp] = slab_U + comp * BASIS_Q;
    sV[comp] = slab_V + comp * BASIS_P;
  }

  // read T (transposed); only the ty == 0 slab loads it for the whole block
  if (ty == 0) {
    read_T_trans_gm2sm<BASIS_Q, BASIS_P>(tx, dT, sT);
  }

  // read U into shared memory
  read_1d<CeedScalar, BASIS_Q, BASIS_NUM_COMP>(dU, cstrdU, sU, tx);

  __syncthreads();
  magma_interp_1d_device<CeedScalar, BASIS_DIM, BASIS_NUM_COMP, BASIS_Q, BASIS_P>(sT, sU, sV, tx);
  __syncthreads();

  // write V back to global memory
  write_1d<CeedScalar, BASIS_P, BASIS_NUM_COMP>(sV, dV, cstrdV, tx);
}

////////////////////////////////////////////////////////////////////////////////
// interp, transpose + accumulate: V += U * T^t, per element and component
extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_MAX_P_Q, MAGMA_MAXTHREADS_1D)) __global__
    void magma_interpta_1d_kernel(const CeedScalar *dT, const CeedScalar *dU, const int estrdU, const int cstrdU, CeedScalar *dV, const int estrdV,
                                  const int cstrdV, const int nelem) {
  MAGMA_DEVICE_SHARED(CeedScalar, shared_data)

  const int tx      = threadIdx.x;
  const int ty      = threadIdx.y;
  const int elem_id = (blockIdx.x * blockDim.y) + ty;

  if (elem_id >= nelem) return;

  // shift global memory pointers by elem stride
  dU += elem_id * estrdU;
  dV += elem_id * estrdV;

  // shared-memory layout: [T matrix | per-ty workspace (all U components, then all V components)]
  CeedScalar *sT = (CeedScalar *)shared_data;
  CeedScalar *sW = sT + BASIS_Q * BASIS_P;

  CeedScalar *sU[BASIS_NUM_COMP];
  CeedScalar *sV[BASIS_NUM_COMP];
  CeedScalar *slab_U = sW + ty * BASIS_NUM_COMP * (BASIS_Q + BASIS_P);
  CeedScalar *slab_V = slab_U + BASIS_NUM_COMP * BASIS_Q;
  for (int comp = 0; comp < BASIS_NUM_COMP; comp++) {
    sU[comp] = slab_U + comp * BASIS_Q;
    sV[comp] = slab_V + comp * BASIS_P;
  }

  // read T (transposed); only the ty == 0 slab loads it for the whole block
  if (ty == 0) {
    read_T_trans_gm2sm<BASIS_Q, BASIS_P>(tx, dT, sT);
  }

  // read U into shared memory
  read_1d<CeedScalar, BASIS_Q, BASIS_NUM_COMP>(dU, cstrdU, sU, tx);

  __syncthreads();
  magma_interp_1d_device<CeedScalar, BASIS_DIM, BASIS_NUM_COMP, BASIS_Q, BASIS_P>(sT, sU, sV, tx);
  __syncthreads();

  // accumulate V into global memory (sum, not overwrite)
  sum_1d<CeedScalar, BASIS_P, BASIS_NUM_COMP>(sV, dV, cstrdV, tx);
}