// Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed

/// @file
/// Internal header for SYCL shared memory tensor product basis
#include <ceed/types.h>

#include "sycl-shared-basis-read-write-templates.h"
#include "sycl-shared-basis-tensor-templates.h"

//
// BASIS_* and T_1D are not defined in this file; they are expected to be provided when this source is compiled.
// Expected relationships:
// BASIS_NUM_NODES = CeedIntPow(BASIS_P_1D,DIM)
// BASIS_NUM_QPTS = CeedIntPow(BASIS_Q_1D,DIM)
//
// The strided read/write calls below pass (..., num_elem, 1, BASIS_NUM_X * num_elem, BASIS_NUM_X, ...);
// presumably these are node stride, component stride and element stride -- confirm against ReadElementStrided*d.

//------------------------------------------------------------------------------
// Interp kernel by dim
//------------------------------------------------------------------------------
// Interp: loads the 1D matrix d_interp_1d into local memory, reads BASIS_NUM_NODES values per element and component
// from d_U, applies Interp1d / InterpTensor2d / InterpTensor3d, and writes BASIS_NUM_QPTS values per element and
// component to d_V.
kernel void Interp(const CeedInt num_elem, global const CeedScalar *restrict d_interp_1d, global const CeedScalar *restrict d_U,
                   global CeedScalar *restrict d_V) {
  local CeedScalar s_B[BASIS_P_1D * BASIS_Q_1D];
  // Per-work-item registers: BASIS_P_1D (input) / BASIS_Q_1D (output) values per component in 3D, one otherwise.
  private CeedScalar r_U[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)];
  private CeedScalar r_V[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)];

  // Local scratch is partitioned by get_local_id(2): T_1D entries each in 1D, T_1D * T_1D in 2D/3D.
  local CeedScalar  scratch[BASIS_INTERP_SCRATCH_SIZE];
  local CeedScalar *elem_scratch = scratch + get_local_id(2) * T_1D * (BASIS_DIM > 1 ? T_1D : 1);

  loadMatrix(BASIS_P_1D * BASIS_Q_1D, d_interp_1d, s_B);
  // s_B must be fully populated before any work-item reads it.
  work_group_barrier(CLK_LOCAL_MEM_FENCE);

  if (BASIS_DIM == 1) {
    ReadElementStrided1d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, d_U, r_U);
    Interp1d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, r_V, elem_scratch);
    WriteElementStrided1d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_V, d_V);

  } else if (BASIS_DIM == 2) {
    ReadElementStrided2d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, d_U, r_U);
    InterpTensor2d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, r_V, elem_scratch);
    WriteElementStrided2d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_V, d_V);

  } else if (BASIS_DIM == 3) {
    ReadElementStrided3d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, d_U, r_U);
    InterpTensor3d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, r_V, elem_scratch);
    WriteElementStrided3d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_V, d_V);
  }
}

// InterpTranspose: same matrix as Interp, applied in the reverse direction. Reads BASIS_NUM_QPTS values per element and
// component from d_U and writes BASIS_NUM_NODES values per element and component to d_V.
kernel void InterpTranspose(const CeedInt num_elem, global const CeedScalar *restrict d_interp_1d, global const CeedScalar *restrict d_U,
                            global CeedScalar *restrict d_V) {
  // local size:
  //   1d: elems_per_block * T_1d
  //   2d,3d: elems_per_block * T_1d * T_1d
  local CeedScalar s_B[BASIS_P_1D * BASIS_Q_1D];
  private CeedScalar r_U[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)];
  private CeedScalar r_V[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)];

  local CeedScalar  scratch[BASIS_INTERP_SCRATCH_SIZE];
  local CeedScalar *elem_scratch = scratch + get_local_id(2) * T_1D * (BASIS_DIM > 1 ? T_1D : 1);

  loadMatrix(BASIS_P_1D * BASIS_Q_1D, d_interp_1d, s_B);
  work_group_barrier(CLK_LOCAL_MEM_FENCE);

  if (BASIS_DIM == 1) {
    ReadElementStrided1d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, d_U, r_U);
    InterpTranspose1d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, r_V, elem_scratch);
    WriteElementStrided1d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, r_V, d_V);

  } else if (BASIS_DIM == 2) {
    ReadElementStrided2d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, d_U, r_U);
    InterpTransposeTensor2d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, r_V, elem_scratch);
    WriteElementStrided2d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, r_V, d_V);

  } else if (BASIS_DIM == 3) {
    ReadElementStrided3d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, d_U, r_U);
    InterpTransposeTensor3d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, r_V, elem_scratch);
    WriteElementStrided3d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, r_V, d_V);
  }
}

//------------------------------------------------------------------------------
// Grad kernel by dim
//------------------------------------------------------------------------------
// Grad: reads BASIS_NUM_NODES values per element and component from d_U and writes gradient data to d_V:
// BASIS_NUM_COMP components in 1D, BASIS_NUM_COMP * BASIS_DIM components in 2D/3D.
// s_G holds BASIS_Q_1D * BASIS_Q_1D entries when BASIS_HAS_COLLOCATED_GRAD is set (used by the collocated 3D path),
// BASIS_Q_1D * BASIS_P_1D entries otherwise.
kernel void Grad(const CeedInt num_elem, global const CeedScalar *restrict d_interp_1d, global const CeedScalar *restrict d_grad_1d,
                 global const CeedScalar *restrict d_U, global CeedScalar *restrict d_V) {
  // The 1D path (Grad1d) reads only s_G, so s_B is shrunk to a single placeholder entry there and is never loaded.
  local CeedScalar s_B[BASIS_DIM > 1 ? BASIS_P_1D * BASIS_Q_1D : 1];
  local CeedScalar s_G[BASIS_Q_1D * (BASIS_HAS_COLLOCATED_GRAD ? BASIS_Q_1D : BASIS_P_1D)];

  private CeedScalar r_U[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)];
  private CeedScalar r_V[BASIS_NUM_COMP * BASIS_DIM * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)];

  local CeedScalar  scratch[BASIS_GRAD_SCRATCH_SIZE];
  local CeedScalar *elem_scratch = scratch + get_local_id(2) * T_1D * (BASIS_DIM > 1 ? T_1D : 1);

  // BASIS_DIM is a compile-time constant, so every work-item takes the same branch here.
  if (BASIS_DIM > 1) {
    loadMatrix(BASIS_P_1D * BASIS_Q_1D, d_interp_1d, s_B);
  }
  loadMatrix(BASIS_Q_1D * (BASIS_HAS_COLLOCATED_GRAD ? BASIS_Q_1D : BASIS_P_1D), d_grad_1d, s_G);
  work_group_barrier(CLK_LOCAL_MEM_FENCE);

  if (BASIS_DIM == 1) {
    ReadElementStrided1d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, d_U, r_U);
    Grad1d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_G, r_V, elem_scratch);
    WriteElementStrided1d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_V, d_V);

  } else if (BASIS_DIM == 2) {
    ReadElementStrided2d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, d_U, r_U);
    GradTensor2d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, s_G, r_V, elem_scratch);
    WriteElementStrided2d(BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_V, d_V);

  } else if (BASIS_DIM == 3) {
    ReadElementStrided3d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, d_U, r_U);
    if (BASIS_HAS_COLLOCATED_GRAD) {
      GradTensorCollocated3d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, s_G, r_V, elem_scratch);
    } else {
      GradTensor3d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, s_G, r_V, elem_scratch);
    }
    WriteElementStrided3d(BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_V, d_V);
  }
}

// GradTranspose: reverse of Grad. Reads BASIS_NUM_COMP (1D) or BASIS_NUM_COMP * BASIS_DIM (2D/3D) components of
// BASIS_NUM_QPTS values per element from d_U and writes BASIS_NUM_NODES values per element and component to d_V.
kernel void GradTranspose(const CeedInt num_elem, global const CeedScalar *restrict d_interp_1d, global const CeedScalar *restrict d_grad_1d,
                          global const CeedScalar *restrict d_U, global CeedScalar *restrict d_V) {
  // The 1D path (GradTranspose1d) reads only s_G, so s_B is shrunk to a single placeholder entry there and is never loaded.
  local CeedScalar s_B[BASIS_DIM > 1 ? BASIS_P_1D * BASIS_Q_1D : 1];
  local CeedScalar s_G[BASIS_Q_1D * (BASIS_HAS_COLLOCATED_GRAD ? BASIS_Q_1D : BASIS_P_1D)];

  private CeedScalar r_U[BASIS_NUM_COMP * BASIS_DIM * (BASIS_DIM > 2 ? BASIS_Q_1D : 1)];
  private CeedScalar r_V[BASIS_NUM_COMP * (BASIS_DIM > 2 ? BASIS_P_1D : 1)];

  local CeedScalar  scratch[BASIS_GRAD_SCRATCH_SIZE];
  local CeedScalar *elem_scratch = scratch + get_local_id(2) * T_1D * (BASIS_DIM > 1 ? T_1D : 1);

  // BASIS_DIM is a compile-time constant, so every work-item takes the same branch here.
  if (BASIS_DIM > 1) {
    loadMatrix(BASIS_P_1D * BASIS_Q_1D, d_interp_1d, s_B);
  }
  loadMatrix(BASIS_Q_1D * (BASIS_HAS_COLLOCATED_GRAD ? BASIS_Q_1D : BASIS_P_1D), d_grad_1d, s_G);
  work_group_barrier(CLK_LOCAL_MEM_FENCE);

  if (BASIS_DIM == 1) {
    ReadElementStrided1d(BASIS_NUM_COMP, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, d_U, r_U);
    GradTranspose1d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_G, r_V, elem_scratch);
    WriteElementStrided1d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, r_V, d_V);

  } else if (BASIS_DIM == 2) {
    ReadElementStrided2d(BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, d_U, r_U);
    GradTransposeTensor2d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, s_G, r_V, elem_scratch);
    WriteElementStrided2d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, r_V, d_V);

  } else if (BASIS_DIM == 3) {
    ReadElementStrided3d(BASIS_NUM_COMP * BASIS_DIM, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, d_U, r_U);
    if (BASIS_HAS_COLLOCATED_GRAD) {
      GradTransposeTensorCollocated3d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, s_G, r_V, elem_scratch);
    } else {
      GradTransposeTensor3d(BASIS_NUM_COMP, BASIS_P_1D, BASIS_Q_1D, r_U, s_B, s_G, r_V, elem_scratch);
    }
    WriteElementStrided3d(BASIS_NUM_COMP, BASIS_P_1D, num_elem, 1, BASIS_NUM_NODES * num_elem, BASIS_NUM_NODES, r_V, d_V);
  }
}

//------------------------------------------------------------------------------
// Weight kernels by dim
//------------------------------------------------------------------------------
// Weight: builds per-quadrature-point values from the 1D array q_weight_1d (via Weight1d / WeightTensor2d /
// WeightTensor3d) and writes BASIS_NUM_QPTS single-component values per element to d_W.
// Dimensions other than 1, 2 or 3 leave d_W untouched.
kernel void Weight(const CeedInt num_elem, global const CeedScalar *restrict q_weight_1d, global CeedScalar *restrict d_W) {
  // BASIS_Q_1D register entries per work-item in 3D, a single entry otherwise.
  private CeedScalar r_weight[BASIS_DIM > 2 ? BASIS_Q_1D : 1];

  switch (BASIS_DIM) {
    case 1:
      Weight1d(BASIS_Q_1D, q_weight_1d, r_weight);
      WriteElementStrided1d(1, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_weight, d_W);
      break;
    case 2:
      WeightTensor2d(BASIS_Q_1D, q_weight_1d, r_weight);
      WriteElementStrided2d(1, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_weight, d_W);
      break;
    case 3:
      WeightTensor3d(BASIS_Q_1D, q_weight_1d, r_weight);
      WriteElementStrided3d(1, BASIS_Q_1D, num_elem, 1, BASIS_NUM_QPTS * num_elem, BASIS_NUM_QPTS, r_weight, d_W);
      break;
    default:
      // No kernel path for other dimensions.
      break;
  }
}