15aed82e4SJeremy L Thompson // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors. 2bd882c8aSJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3bd882c8aSJames Wright // 4bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause 5bd882c8aSJames Wright // 6bd882c8aSJames Wright // This file is part of CEED: http://github.com/ceed 7bd882c8aSJames Wright 8bd882c8aSJames Wright /// @file 9bd882c8aSJames Wright /// Internal header for SYCL shared memory tensor product basis 10bd882c8aSJames Wright 11*c0b5abf0SJeremy L Thompson #include <ceed/types.h> 12bd882c8aSJames Wright 13bd882c8aSJames Wright #include "sycl-shared-basis-read-write-templates.h" 14bd882c8aSJames Wright #include "sycl-shared-basis-tensor-templates.h" 15bd882c8aSJames Wright 16bd882c8aSJames Wright // 17bd882c8aSJames Wright // BASIS_NUM_NODES = CeedIntPow(BASIS_P_1D,DIM) 18bd882c8aSJames Wright // BASIS_NUM_QPTS = CeedIntPow(BASIS_Q_1D,DIM) 19bd882c8aSJames Wright 20bd882c8aSJames Wright //------------------------------------------------------------------------------ 21bd882c8aSJames Wright // Interp kernel by dim 22bd882c8aSJames Wright //------------------------------------------------------------------------------ 23bd882c8aSJames Wright kernel void Interp(const CeedInt num_elem, global const CeedScalar *restrict d_interp_1d, global const CeedScalar *restrict d_U, 24bd882c8aSJames Wright global CeedScalar *restrict d_V) { 25bd882c8aSJames Wright local CeedScalar s_B[BASIS_P_1D * BASIS_Q_1D]; 26bd882c8aSJames Wright private 27bd882c8aSJames Wright CeedScalar r_U[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)]; 28bd882c8aSJames Wright private 29bd882c8aSJames Wright CeedScalar r_V[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)]; 30bd882c8aSJames Wright 31bd882c8aSJames Wright local CeedScalar scratch[BASIS_INTERP_SCRATCH_SIZE]; 32bd882c8aSJames Wright local CeedScalar *elem_scratch = scratch + get_local_id(2) * T_1D * (BASIS_DIM > 1 ? T_1D : 1); 33bd882c8aSJames Wright 34bd882c8aSJames Wright loadMatrix(BASIS_P_1D * BASIS_Q_1D, d_interp_1d, s_B); 35bd882c8aSJames Wright work_group_barrier(CLK_LOCAL_MEM_FENCE); 36bd882c8aSJames Wright 37bd882c8aSJames Wright if (BASIS_DIM == 1) { 38bd882c8aSJames Wright ReadElementStrided1d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, d_U, r_U); 39bd882c8aSJames Wright Interp1d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, r_V, elem_scratch); 40bd882c8aSJames Wright WriteElementStrided1d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_V, d_V); 41bd882c8aSJames Wright 42bd882c8aSJames Wright } else if (BASIS_DIM == 2) { 43bd882c8aSJames Wright ReadElementStrided2d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, d_U, r_U); 44bd882c8aSJames Wright InterpTensor2d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, r_V, elem_scratch); 45bd882c8aSJames Wright WriteElementStrided2d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_V, d_V); 46bd882c8aSJames Wright 47bd882c8aSJames Wright } else if (BASIS_DIM == 3) { 48bd882c8aSJames Wright ReadElementStrided3d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, d_U, r_U); 49bd882c8aSJames Wright InterpTensor3d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, r_V, elem_scratch); 50bd882c8aSJames Wright WriteElementStrided3d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_V, d_V); 51bd882c8aSJames Wright } 52bd882c8aSJames Wright } 53bd882c8aSJames Wright 54bd882c8aSJames Wright kernel void InterpTranspose(const CeedInt num_elem, global const CeedScalar *restrict d_interp_1d, global const CeedScalar *restrict d_U, 55bd882c8aSJames Wright global CeedScalar *restrict d_V) { 56bd882c8aSJames Wright // local size: 57bd882c8aSJames Wright // 1d: elems_per_block * T_1d 58bd882c8aSJames Wright // 2d,3d: elems_per_block * T_1d * T_1d 59bd882c8aSJames Wright local CeedScalar s_B[BASIS_P_1D * BASIS_Q_1D]; 60bd882c8aSJames Wright private 61bd882c8aSJames Wright CeedScalar r_U[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)]; 62bd882c8aSJames Wright private 63bd882c8aSJames Wright CeedScalar r_V[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)]; 64bd882c8aSJames Wright 65bd882c8aSJames Wright local CeedScalar scratch[BASIS_INTERP_SCRATCH_SIZE]; 66bd882c8aSJames Wright local CeedScalar *elem_scratch = scratch + get_local_id(2) * T_1D * (BASIS_DIM > 1 ? T_1D : 1); 67bd882c8aSJames Wright 68bd882c8aSJames Wright loadMatrix(BASIS_P_1D * BASIS_Q_1D, d_interp_1d, s_B); 69bd882c8aSJames Wright work_group_barrier(CLK_LOCAL_MEM_FENCE); 70bd882c8aSJames Wright 71bd882c8aSJames Wright if (BASIS_DIM == 1) { 72bd882c8aSJames Wright ReadElementStrided1d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, d_U, r_U); 73bd882c8aSJames Wright InterpTranspose1d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, r_V, elem_scratch); 74bd882c8aSJames Wright WriteElementStrided1d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, r_V, d_V); 75bd882c8aSJames Wright 76bd882c8aSJames Wright } else if (BASIS_DIM == 2) { 77bd882c8aSJames Wright ReadElementStrided2d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, d_U, r_U); 78bd882c8aSJames Wright InterpTransposeTensor2d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, r_V, elem_scratch); 79bd882c8aSJames Wright WriteElementStrided2d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, r_V, d_V); 80bd882c8aSJames Wright 81bd882c8aSJames Wright } else if (BASIS_DIM == 3) { 82bd882c8aSJames Wright ReadElementStrided3d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, d_U, r_U); 83bd882c8aSJames Wright InterpTransposeTensor3d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, r_V, elem_scratch); 84bd882c8aSJames Wright WriteElementStrided3d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, r_V, d_V); 85bd882c8aSJames Wright } 86bd882c8aSJames Wright } 87bd882c8aSJames Wright 88bd882c8aSJames Wright //------------------------------------------------------------------------------ 89bd882c8aSJames Wright // Grad kernel by dim 90bd882c8aSJames Wright //------------------------------------------------------------------------------ 91bd882c8aSJames Wright kernel void Grad(const CeedInt num_elem, global const CeedScalar *restrict d_interp_1d, global const CeedScalar *restrict d_grad_1d, 92bd882c8aSJames Wright global const CeedScalar *restrict d_U, global CeedScalar *restrict d_V) { 93bd882c8aSJames Wright local CeedScalar s_B[BASIS_P_1D * BASIS_Q_1D]; // Todo, don't allocate s_B for dimension 1 94bd882c8aSJames Wright local CeedScalar s_G[BASIS_Q_1D * (BASIS_HAS_COLLOCATED_GRAD ? BASIS_Q_1D : BASIS_P_1D)]; 95bd882c8aSJames Wright 96bd882c8aSJames Wright private 97bd882c8aSJames Wright CeedScalar r_U[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)]; 98bd882c8aSJames Wright private 99bd882c8aSJames Wright CeedScalar r_V[BASIS_NUM_COMP * BASIS_DIM * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)]; 100bd882c8aSJames Wright 101bd882c8aSJames Wright local CeedScalar scratch[BASIS_GRAD_SCRATCH_SIZE]; 102bd882c8aSJames Wright local CeedScalar *elem_scratch = scratch + get_local_id(2) * T_1D * (BASIS_DIM > 1 ? T_1D : 1); 103bd882c8aSJames Wright 104bd882c8aSJames Wright loadMatrix(BASIS_P_1D * BASIS_Q_1D, d_interp_1d, s_B); 105bd882c8aSJames Wright loadMatrix(BASIS_Q_1D * (BASIS_HAS_COLLOCATED_GRAD ? BASIS_Q_1D : BASIS_P_1D), d_grad_1d, s_G); 106bd882c8aSJames Wright work_group_barrier(CLK_LOCAL_MEM_FENCE); 107bd882c8aSJames Wright 108bd882c8aSJames Wright if (BASIS_DIM == 1) { 109bd882c8aSJames Wright ReadElementStrided1d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, d_U, r_U); 110bd882c8aSJames Wright Grad1d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_G, r_V, elem_scratch); 111bd882c8aSJames Wright WriteElementStrided1d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_V, d_V); 112bd882c8aSJames Wright 113bd882c8aSJames Wright } else if (BASIS_DIM == 2) { 114bd882c8aSJames Wright ReadElementStrided2d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, d_U, r_U); 115bd882c8aSJames Wright GradTensor2d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, s_G, r_V, elem_scratch); 116bd882c8aSJames Wright WriteElementStrided2d(BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_V, d_V); 117bd882c8aSJames Wright 118bd882c8aSJames Wright } else if (BASIS_DIM == 3) { 119bd882c8aSJames Wright ReadElementStrided3d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, d_U, r_U); 120bd882c8aSJames Wright if (BASIS_HAS_COLLOCATED_GRAD) GradTensorCollocated3d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, s_G, r_V, elem_scratch); 121bd882c8aSJames Wright else GradTensor3d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, s_G, r_V, elem_scratch); 122bd882c8aSJames Wright WriteElementStrided3d(BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_V, d_V); 123bd882c8aSJames Wright } 124bd882c8aSJames Wright } 125bd882c8aSJames Wright 126bd882c8aSJames Wright kernel void GradTranspose(const CeedInt num_elem, global const CeedScalar *restrict d_interp_1d, global const CeedScalar *restrict d_grad_1d, 127bd882c8aSJames Wright global const CeedScalar *restrict d_U, global CeedScalar *restrict d_V) { 128bd882c8aSJames Wright local CeedScalar s_B[BASIS_P_1D * BASIS_Q_1D]; // Todo, don't allocate s_B for dimension 1 129bd882c8aSJames Wright local CeedScalar s_G[BASIS_Q_1D * (BASIS_HAS_COLLOCATED_GRAD ? BASIS_Q_1D : BASIS_P_1D)]; 130bd882c8aSJames Wright 131bd882c8aSJames Wright private 132bd882c8aSJames Wright CeedScalar r_U[BASIS_NUM_COMP * BASIS_DIM * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)]; 133bd882c8aSJames Wright private 134bd882c8aSJames Wright CeedScalar r_V[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)]; 135bd882c8aSJames Wright 136bd882c8aSJames Wright local CeedScalar scratch[BASIS_GRAD_SCRATCH_SIZE]; 137bd882c8aSJames Wright local CeedScalar *elem_scratch = scratch + get_local_id(2) * T_1D * (BASIS_DIM > 1 ? T_1D : 1); 138bd882c8aSJames Wright 139bd882c8aSJames Wright loadMatrix(BASIS_P_1D * BASIS_Q_1D, d_interp_1d, s_B); 140bd882c8aSJames Wright loadMatrix(BASIS_Q_1D * (BASIS_HAS_COLLOCATED_GRAD ? BASIS_Q_1D : BASIS_P_1D), d_grad_1d, s_G); 141bd882c8aSJames Wright work_group_barrier(CLK_LOCAL_MEM_FENCE); 142bd882c8aSJames Wright 143bd882c8aSJames Wright if (BASIS_DIM == 1) { 144bd882c8aSJames Wright ReadElementStrided1d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, d_U, r_U); 145bd882c8aSJames Wright GradTranspose1d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_G, r_V, elem_scratch); 146bd882c8aSJames Wright WriteElementStrided1d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, r_V, d_V); 147bd882c8aSJames Wright 148bd882c8aSJames Wright } else if (BASIS_DIM == 2) { 149bd882c8aSJames Wright ReadElementStrided2d(BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, d_U, r_U); 150bd882c8aSJames Wright GradTransposeTensor2d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, s_G, r_V, elem_scratch); 151bd882c8aSJames Wright WriteElementStrided2d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, r_V, d_V); 152bd882c8aSJames Wright 153bd882c8aSJames Wright } else if (BASIS_DIM == 3) { 154bd882c8aSJames Wright ReadElementStrided3d(BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, d_U, r_U); 155bd882c8aSJames Wright if (BASIS_HAS_COLLOCATED_GRAD) GradTransposeTensorCollocated3d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, s_G, r_V, elem_scratch); 156bd882c8aSJames Wright else GradTransposeTensor3d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, s_G, r_V, elem_scratch); 157bd882c8aSJames Wright WriteElementStrided3d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, r_V, d_V); 158bd882c8aSJames Wright } 159bd882c8aSJames Wright } 160bd882c8aSJames Wright 161bd882c8aSJames Wright //------------------------------------------------------------------------------ 162bd882c8aSJames Wright // Weight kernels by dim 163bd882c8aSJames Wright //------------------------------------------------------------------------------ 164bd882c8aSJames Wright kernel void Weight(const CeedInt num_elem, global const CeedScalar *restrict q_weight_1d, global CeedScalar *restrict d_W) { 165bd882c8aSJames Wright private 166bd882c8aSJames Wright CeedScalar r_W[BASIS_DIM > 2 ? BASIS_Q_1D : 1]; 167bd882c8aSJames Wright 168bd882c8aSJames Wright // void prefetch(q_weight_1d,BASIS_Q_1D); 169bd882c8aSJames Wright 170bd882c8aSJames Wright if (BASIS_DIM == 1) { 171bd882c8aSJames Wright Weight1d(BASIS_Q_1D, q_weight_1d, r_W); 172bd882c8aSJames Wright WriteElementStrided1d(1, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_W, d_W); 173bd882c8aSJames Wright 174bd882c8aSJames Wright } else if (BASIS_DIM == 2) { 175bd882c8aSJames Wright WeightTensor2d(BASIS_Q_1D, q_weight_1d, r_W); 176bd882c8aSJames Wright WriteElementStrided2d(1, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_W, d_W); 177bd882c8aSJames Wright 178bd882c8aSJames Wright } else if (BASIS_DIM == 3) { 179bd882c8aSJames Wright WeightTensor3d(BASIS_Q_1D, q_weight_1d, r_W); 180bd882c8aSJames Wright WriteElementStrided3d(1, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_W, d_W); 181bd882c8aSJames Wright } 182bd882c8aSJames Wright } 183