// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed

/// @file
/// Internal header for SYCL shared memory tensor product basis
#ifndef CEED_SYCL_SHARED_BASIS_TENSOR_H
#define CEED_SYCL_SHARED_BASIS_TENSOR_H

#include <ceed.h>

#include "sycl-shared-basis-read-write-templates.h"
#include "sycl-shared-basis-tensor-templates.h"

//------------------------------------------------------------------------------
// Compile-time parameters
//
// None of the following macros are defined in this file; they must be defined
// before this source is compiled (presumably as JIT definitions supplied by the
// host backend -- confirm there):
//   BASIS_DIM                  dimension; the kernels handle 1, 2 and 3 (any
//                              other value makes every kernel a no-op)
//   BASIS_NUM_COMP             number of field components
//   BASIS_P_1D, BASIS_Q_1D     1D node count and 1D quadrature-point count
//   BASIS_NUM_NODES            = CeedIntPow(BASIS_P_1D, BASIS_DIM)
//   BASIS_NUM_QPTS             = CeedIntPow(BASIS_Q_1D, BASIS_DIM)
//   BASIS_HAS_COLLOCATED_GRAD  nonzero: d_grad_1d holds BASIS_Q_1D * BASIS_Q_1D
//                              entries and 3D uses the *Collocated3d routines;
//                              zero: d_grad_1d holds BASIS_Q_1D * BASIS_P_1D entries
//   BASIS_INTERP_SCRATCH_SIZE  local scratch length for Interp/InterpTranspose
//   BASIS_GRAD_SCRATCH_SIZE    local scratch length for Grad/GradTranspose
//   T_1D                       per-direction extent used to slice the scratch
//
// Work-group layout (per the original authors' note):
//   1d:    elems_per_block * T_1d work-items
//   2d,3d: elems_per_block * T_1d * T_1d work-items
// Each kernel offsets its scratch pointer by get_local_id(2), so each value of
// local id 2 gets its own slice of T_1D (1D) or T_1D * T_1D (2D/3D) entries --
// presumably one element per slice; confirm against the host launch config.
//
// Element data is read/written with the strided helpers using the arguments
// (1, N * num_elem, N), N = BASIS_NUM_NODES or BASIS_NUM_QPTS -- presumably the
// node, component and element strides; see the read/write templates.
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// Interp kernel by dim
//
// Maps the BASIS_NUM_NODES node values of each element (d_U) to its
// BASIS_NUM_QPTS quadrature-point values (d_V) with the 1D interpolation
// matrix d_interp_1d (BASIS_P_1D * BASIS_Q_1D entries).
//   num_elem     number of elements (enters the strided layout)
//   d_interp_1d  1D interpolation matrix, staged into local memory
//   d_U          input node values
//   d_V          output quadrature-point values
//------------------------------------------------------------------------------
kernel void Interp(const CeedInt num_elem, global const CeedScalar* restrict d_interp_1d, global const CeedScalar* restrict d_U,
                   global CeedScalar* restrict d_V) {
  local CeedScalar s_B[BASIS_P_1D * BASIS_Q_1D];
  // Per-work-item storage: in 3D a full 1D line (P_1D in, Q_1D out) per
  // component; in 1D/2D a single value per component.
  private CeedScalar r_U[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)];
  private CeedScalar r_V[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)];

  local CeedScalar  scratch[BASIS_INTERP_SCRATCH_SIZE];
  local CeedScalar* elem_scratch = scratch + get_local_id(2) * T_1D * (BASIS_DIM > 1 ? T_1D : 1);

  // loadMatrix is called by every work-item; the barrier guarantees s_B is
  // completely written before any work-item reads it.
  loadMatrix(BASIS_P_1D * BASIS_Q_1D, d_interp_1d, s_B);
  work_group_barrier(CLK_LOCAL_MEM_FENCE);

  if (BASIS_DIM == 1) {
    ReadElementStrided1d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, d_U, r_U);
    Interp1d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, r_V, elem_scratch);
    WriteElementStrided1d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_V, d_V);

  } else if (BASIS_DIM == 2) {
    ReadElementStrided2d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, d_U, r_U);
    InterpTensor2d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, r_V, elem_scratch);
    WriteElementStrided2d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_V, d_V);

  } else if (BASIS_DIM == 3) {
    ReadElementStrided3d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, d_U, r_U);
    InterpTensor3d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, r_V, elem_scratch);
    WriteElementStrided3d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_V, d_V);
  }
}

//------------------------------------------------------------------------------
// InterpTranspose kernel by dim
//
// Transpose of Interp: maps the BASIS_NUM_QPTS quadrature-point values of each
// element (d_U) back to its BASIS_NUM_NODES node values (d_V) using the same
// 1D matrix d_interp_1d.
//   num_elem     number of elements (enters the strided layout)
//   d_interp_1d  1D interpolation matrix, staged into local memory
//   d_U          input quadrature-point values
//   d_V          output node values
//------------------------------------------------------------------------------
kernel void InterpTranspose(const CeedInt num_elem, global const CeedScalar* restrict d_interp_1d, global const CeedScalar* restrict d_U,
                            global CeedScalar* restrict d_V) {
  local CeedScalar s_B[BASIS_P_1D * BASIS_Q_1D];
  // Per-work-item storage: sizes are the mirror image of Interp (Q_1D in,
  // P_1D out in 3D; one value per component otherwise).
  private CeedScalar r_U[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)];
  private CeedScalar r_V[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)];

  local CeedScalar  scratch[BASIS_INTERP_SCRATCH_SIZE];
  local CeedScalar* elem_scratch = scratch + get_local_id(2) * T_1D * (BASIS_DIM > 1 ? T_1D : 1);

  // Stage the 1D matrix; barrier before any work-item reads s_B.
  loadMatrix(BASIS_P_1D * BASIS_Q_1D, d_interp_1d, s_B);
  work_group_barrier(CLK_LOCAL_MEM_FENCE);

  if (BASIS_DIM == 1) {
    ReadElementStrided1d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, d_U, r_U);
    InterpTranspose1d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, r_V, elem_scratch);
    WriteElementStrided1d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, r_V, d_V);

  } else if (BASIS_DIM == 2) {
    ReadElementStrided2d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, d_U, r_U);
    InterpTransposeTensor2d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, r_V, elem_scratch);
    WriteElementStrided2d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, r_V, d_V);

  } else if (BASIS_DIM == 3) {
    ReadElementStrided3d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, d_U, r_U);
    InterpTransposeTensor3d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, r_V, elem_scratch);
    WriteElementStrided3d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, r_V, d_V);
  }
}

//------------------------------------------------------------------------------
// Grad kernel by dim
//
// Maps the node values of each element (d_U, BASIS_NUM_COMP components) to
// gradient values at its quadrature points (d_V, BASIS_NUM_COMP * BASIS_DIM
// components in 2D/3D; BASIS_NUM_COMP in 1D, where the two coincide).
//   num_elem     number of elements (enters the strided layout)
//   d_interp_1d  1D interpolation matrix (BASIS_P_1D * BASIS_Q_1D entries)
//   d_grad_1d    1D gradient matrix (size depends on BASIS_HAS_COLLOCATED_GRAD)
//   d_U          input node values
//   d_V          output quadrature-point gradient values
//------------------------------------------------------------------------------
kernel void Grad(const CeedInt num_elem, global const CeedScalar* restrict d_interp_1d, global const CeedScalar* restrict d_grad_1d,
                 global const CeedScalar* restrict d_U, global CeedScalar* restrict d_V) {
  local CeedScalar s_B[BASIS_P_1D * BASIS_Q_1D];  // TODO: don't allocate s_B for dimension 1 (Grad1d only uses s_G)
  local CeedScalar s_G[BASIS_Q_1D * (BASIS_HAS_COLLOCATED_GRAD ? BASIS_Q_1D : BASIS_P_1D)];

  // Output storage holds BASIS_DIM values per component (one per derivative direction).
  private CeedScalar r_U[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)];
  private CeedScalar r_V[BASIS_NUM_COMP * BASIS_DIM * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)];

  local CeedScalar  scratch[BASIS_GRAD_SCRATCH_SIZE];
  local CeedScalar* elem_scratch = scratch + get_local_id(2) * T_1D * (BASIS_DIM > 1 ? T_1D : 1);

  // Stage both 1D matrices; one barrier covers both before either is read.
  loadMatrix(BASIS_P_1D * BASIS_Q_1D, d_interp_1d, s_B);
  loadMatrix(BASIS_Q_1D * (BASIS_HAS_COLLOCATED_GRAD ? BASIS_Q_1D : BASIS_P_1D), d_grad_1d, s_G);
  work_group_barrier(CLK_LOCAL_MEM_FENCE);

  if (BASIS_DIM == 1) {
    ReadElementStrided1d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, d_U, r_U);
    Grad1d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_G, r_V, elem_scratch);
    WriteElementStrided1d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_V, d_V);

  } else if (BASIS_DIM == 2) {
    ReadElementStrided2d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, d_U, r_U);
    GradTensor2d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, s_G, r_V, elem_scratch);
    WriteElementStrided2d(BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_V, d_V);

  } else if (BASIS_DIM == 3) {
    ReadElementStrided3d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, d_U, r_U);
    if (BASIS_HAS_COLLOCATED_GRAD) {
      GradTensorCollocated3d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, s_G, r_V, elem_scratch);
    } else {
      GradTensor3d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, s_G, r_V, elem_scratch);
    }
    WriteElementStrided3d(BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_V, d_V);
  }
}

//------------------------------------------------------------------------------
// GradTranspose kernel by dim
//
// Transpose of Grad: maps quadrature-point gradient values of each element
// (d_U, BASIS_NUM_COMP * BASIS_DIM components in 2D/3D) back to its node
// values (d_V, BASIS_NUM_COMP components).
//   num_elem     number of elements (enters the strided layout)
//   d_interp_1d  1D interpolation matrix (BASIS_P_1D * BASIS_Q_1D entries)
//   d_grad_1d    1D gradient matrix (size depends on BASIS_HAS_COLLOCATED_GRAD)
//   d_U          input quadrature-point gradient values
//   d_V          output node values
//------------------------------------------------------------------------------
kernel void GradTranspose(const CeedInt num_elem, global const CeedScalar* restrict d_interp_1d, global const CeedScalar* restrict d_grad_1d,
                          global const CeedScalar* restrict d_U, global CeedScalar* restrict d_V) {
  local CeedScalar s_B[BASIS_P_1D * BASIS_Q_1D];  // TODO: don't allocate s_B for dimension 1 (GradTranspose1d only uses s_G)
  local CeedScalar s_G[BASIS_Q_1D * (BASIS_HAS_COLLOCATED_GRAD ? BASIS_Q_1D : BASIS_P_1D)];

  // Input storage holds BASIS_DIM values per component (one per derivative direction).
  private CeedScalar r_U[BASIS_NUM_COMP * BASIS_DIM * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)];
  private CeedScalar r_V[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)];

  local CeedScalar  scratch[BASIS_GRAD_SCRATCH_SIZE];
  local CeedScalar* elem_scratch = scratch + get_local_id(2) * T_1D * (BASIS_DIM > 1 ? T_1D : 1);

  // Stage both 1D matrices; one barrier covers both before either is read.
  loadMatrix(BASIS_P_1D * BASIS_Q_1D, d_interp_1d, s_B);
  loadMatrix(BASIS_Q_1D * (BASIS_HAS_COLLOCATED_GRAD ? BASIS_Q_1D : BASIS_P_1D), d_grad_1d, s_G);
  work_group_barrier(CLK_LOCAL_MEM_FENCE);

  if (BASIS_DIM == 1) {
    ReadElementStrided1d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, d_U, r_U);
    GradTranspose1d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_G, r_V, elem_scratch);
    WriteElementStrided1d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, r_V, d_V);

  } else if (BASIS_DIM == 2) {
    ReadElementStrided2d(BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, d_U, r_U);
    GradTransposeTensor2d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, s_G, r_V, elem_scratch);
    WriteElementStrided2d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, r_V, d_V);

  } else if (BASIS_DIM == 3) {
    ReadElementStrided3d(BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, d_U, r_U);
    if (BASIS_HAS_COLLOCATED_GRAD) {
      GradTransposeTensorCollocated3d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, s_G, r_V, elem_scratch);
    } else {
      GradTransposeTensor3d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, s_G, r_V, elem_scratch);
    }
    WriteElementStrided3d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, r_V, d_V);
  }
}

//------------------------------------------------------------------------------
// Weight kernels by dim
//
// Builds the quadrature weights of each element from the 1D weights
// q_weight_1d (BASIS_Q_1D entries) and writes them to d_W as a single
// component. q_weight_1d is read directly from global memory; no local
// staging or barrier is used.
//   num_elem     number of elements (enters the strided layout)
//   q_weight_1d  1D quadrature weights
//   d_W          output weights, BASIS_NUM_QPTS per element
//------------------------------------------------------------------------------
kernel void Weight(const CeedInt num_elem, global const CeedScalar* restrict q_weight_1d, global CeedScalar* restrict d_W) {
  // In 3D each work-item holds a full 1D line of weights; otherwise one value.
  private CeedScalar r_W[BASIS_DIM > 2 ? BASIS_Q_1D : 1];

  if (BASIS_DIM == 1) {
    Weight1d(BASIS_Q_1D, q_weight_1d, r_W);
    WriteElementStrided1d(1, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_W, d_W);

  } else if (BASIS_DIM == 2) {
    WeightTensor2d(BASIS_Q_1D, q_weight_1d, r_W);
    WriteElementStrided2d(1, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_W, d_W);

  } else if (BASIS_DIM == 3) {
    WeightTensor3d(BASIS_Q_1D, q_weight_1d, r_W);
    WriteElementStrided3d(1, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_W, d_W);
  }
}

//------------------------------------------------------------------------------

#endif  // CEED_SYCL_SHARED_BASIS_TENSOR_H