15aed82e4SJeremy L Thompson // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors. 2bd882c8aSJames Wright // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3bd882c8aSJames Wright // 4bd882c8aSJames Wright // SPDX-License-Identifier: BSD-2-Clause 5bd882c8aSJames Wright // 6bd882c8aSJames Wright // This file is part of CEED: http://github.com/ceed 7bd882c8aSJames Wright 8bd882c8aSJames Wright /// @file 9bd882c8aSJames Wright /// Internal header for SYCL shared memory basis read/write templates 10bd882c8aSJames Wright 11*c0b5abf0SJeremy L Thompson #include <ceed/types.h> 12bd882c8aSJames Wright #include "sycl-types.h" 13bd882c8aSJames Wright 14bd882c8aSJames Wright //------------------------------------------------------------------------------ 15bd882c8aSJames Wright // Helper function: load matrices for basis actions 16bd882c8aSJames Wright //------------------------------------------------------------------------------ 17bd882c8aSJames Wright inline void loadMatrix(const CeedInt N, const CeedScalar *restrict d_B, CeedScalar *restrict B) { 18bd882c8aSJames Wright const CeedInt item_id = get_local_linear_id(); 19bd882c8aSJames Wright const CeedInt group_size = get_local_size(0) * get_local_size(1) * get_local_size(2); 20bd882c8aSJames Wright for (CeedInt i = item_id; i < N; i += group_size) B[i] = d_B[i]; 21bd882c8aSJames Wright } 22bd882c8aSJames Wright 23bd882c8aSJames Wright //------------------------------------------------------------------------------ 24bd882c8aSJames Wright // 1D 25bd882c8aSJames Wright //------------------------------------------------------------------------------ 26bd882c8aSJames Wright 27bd882c8aSJames Wright //------------------------------------------------------------------------------ 28bd882c8aSJames Wright // E-vector -> single element 29bd882c8aSJames Wright //------------------------------------------------------------------------------ 30bd882c8aSJames Wright inline void ReadElementStrided1d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node, 31bd882c8aSJames Wright const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar *restrict d_u, 32bd882c8aSJames Wright private CeedScalar *restrict r_u) { 33bd882c8aSJames Wright const CeedInt item_id_x = get_local_id(0); 34bd882c8aSJames Wright const CeedInt elem = get_global_id(2); 35bd882c8aSJames Wright 36bd882c8aSJames Wright if (item_id_x < P_1D && elem < num_elem) { 37bd882c8aSJames Wright const CeedInt node = item_id_x; 38bd882c8aSJames Wright const CeedInt ind = node * strides_node + elem * strides_elem; 39bd882c8aSJames Wright for (CeedInt comp = 0; comp < NUM_COMP; comp++) { 40bd882c8aSJames Wright r_u[comp] = d_u[ind + comp * strides_comp]; 41bd882c8aSJames Wright } 42bd882c8aSJames Wright } 43bd882c8aSJames Wright } 44bd882c8aSJames Wright 45bd882c8aSJames Wright //------------------------------------------------------------------------------ 46bd882c8aSJames Wright // Single element -> E-vector 47bd882c8aSJames Wright //------------------------------------------------------------------------------ 48bd882c8aSJames Wright inline void WriteElementStrided1d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node, 49bd882c8aSJames Wright const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar *restrict r_v, 50bd882c8aSJames Wright global CeedScalar *restrict d_v) { 51bd882c8aSJames Wright const CeedInt item_id_x = get_local_id(0); 52bd882c8aSJames Wright const CeedInt elem = get_global_id(2); 53bd882c8aSJames Wright 54bd882c8aSJames Wright if (item_id_x < P_1D && elem < num_elem) { 55bd882c8aSJames Wright const CeedInt node = item_id_x; 56bd882c8aSJames Wright const CeedInt ind = node * strides_node + elem * strides_elem; 57bd882c8aSJames Wright for (CeedInt comp = 0; comp < NUM_COMP; comp++) { 58bd882c8aSJames Wright d_v[ind + comp * strides_comp] = r_v[comp]; 59bd882c8aSJames Wright } 60bd882c8aSJames Wright } 61bd882c8aSJames Wright } 62bd882c8aSJames Wright 63bd882c8aSJames Wright //------------------------------------------------------------------------------ 64bd882c8aSJames Wright // 2D 65bd882c8aSJames Wright //------------------------------------------------------------------------------ 66bd882c8aSJames Wright 67bd882c8aSJames Wright //------------------------------------------------------------------------------ 68bd882c8aSJames Wright // E-vector -> single element 69bd882c8aSJames Wright //------------------------------------------------------------------------------ 70bd882c8aSJames Wright inline void ReadElementStrided2d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node, 71bd882c8aSJames Wright const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar *restrict d_u, 72bd882c8aSJames Wright private CeedScalar *restrict r_u) { 73bd882c8aSJames Wright const CeedInt item_id_x = get_local_id(0); 74bd882c8aSJames Wright const CeedInt item_id_y = get_local_id(1); 75bd882c8aSJames Wright const CeedInt elem = get_global_id(2); 76bd882c8aSJames Wright 77bd882c8aSJames Wright if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) { 78bd882c8aSJames Wright const CeedInt node = item_id_x + item_id_y * P_1D; 79bd882c8aSJames Wright const CeedInt ind = node * strides_node + elem * strides_elem; 80bd882c8aSJames Wright for (CeedInt comp = 0; comp < NUM_COMP; comp++) { 81bd882c8aSJames Wright r_u[comp] = d_u[ind + comp * strides_comp]; 82bd882c8aSJames Wright } 83bd882c8aSJames Wright } 84bd882c8aSJames Wright } 85bd882c8aSJames Wright 86bd882c8aSJames Wright //------------------------------------------------------------------------------ 87bd882c8aSJames Wright // Single element -> E-vector 88bd882c8aSJames Wright //------------------------------------------------------------------------------ 89bd882c8aSJames Wright inline void WriteElementStrided2d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node, 90bd882c8aSJames Wright const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar *restrict r_v, 91bd882c8aSJames Wright global CeedScalar *restrict d_v) { 92bd882c8aSJames Wright const CeedInt item_id_x = get_local_id(0); 93bd882c8aSJames Wright const CeedInt item_id_y = get_local_id(1); 94bd882c8aSJames Wright const CeedInt elem = get_global_id(2); 95bd882c8aSJames Wright 96bd882c8aSJames Wright if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) { 97bd882c8aSJames Wright const CeedInt node = item_id_x + item_id_y * P_1D; 98bd882c8aSJames Wright const CeedInt ind = node * strides_node + elem * strides_elem; 99bd882c8aSJames Wright for (CeedInt comp = 0; comp < NUM_COMP; comp++) { 100bd882c8aSJames Wright d_v[ind + comp * strides_comp] = r_v[comp]; 101bd882c8aSJames Wright } 102bd882c8aSJames Wright } 103bd882c8aSJames Wright } 104bd882c8aSJames Wright 105bd882c8aSJames Wright //------------------------------------------------------------------------------ 106bd882c8aSJames Wright // 3D 107bd882c8aSJames Wright //------------------------------------------------------------------------------ 108bd882c8aSJames Wright 109bd882c8aSJames Wright //------------------------------------------------------------------------------ 110bd882c8aSJames Wright // E-vector -> single element 111bd882c8aSJames Wright //------------------------------------------------------------------------------ 112bd882c8aSJames Wright inline void ReadElementStrided3d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node, 113bd882c8aSJames Wright const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar *restrict d_u, 114bd882c8aSJames Wright private CeedScalar *restrict r_u) { 115bd882c8aSJames Wright const CeedInt item_id_x = get_local_id(0); 116bd882c8aSJames Wright const CeedInt item_id_y = get_local_id(1); 117bd882c8aSJames Wright const CeedInt elem = get_global_id(2); 118bd882c8aSJames Wright 119bd882c8aSJames Wright if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) { 120bd882c8aSJames Wright for (CeedInt z = 0; z < P_1D; z++) { 121bd882c8aSJames Wright const CeedInt node = item_id_x + item_id_y * P_1D + z * P_1D * P_1D; 122bd882c8aSJames Wright const CeedInt ind = node * strides_node + elem * strides_elem; 123bd882c8aSJames Wright for (CeedInt comp = 0; comp < NUM_COMP; comp++) { 124bd882c8aSJames Wright r_u[z + comp * P_1D] = d_u[ind + comp * strides_comp]; 125bd882c8aSJames Wright } 126bd882c8aSJames Wright } 127bd882c8aSJames Wright } 128bd882c8aSJames Wright } 129bd882c8aSJames Wright 130bd882c8aSJames Wright //------------------------------------------------------------------------------ 131bd882c8aSJames Wright // Single element -> E-vector 132bd882c8aSJames Wright //------------------------------------------------------------------------------ 133bd882c8aSJames Wright inline void WriteElementStrided3d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node, 134bd882c8aSJames Wright const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar *restrict r_v, 135bd882c8aSJames Wright global CeedScalar *restrict d_v) { 136bd882c8aSJames Wright const CeedInt item_id_x = get_local_id(0); 137bd882c8aSJames Wright const CeedInt item_id_y = get_local_id(1); 138bd882c8aSJames Wright const CeedInt elem = get_global_id(2); 139bd882c8aSJames Wright 140bd882c8aSJames Wright if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) { 141bd882c8aSJames Wright for (CeedInt z = 0; z < P_1D; z++) { 142bd882c8aSJames Wright const CeedInt node = item_id_x + item_id_y * P_1D + z * P_1D * P_1D; 143bd882c8aSJames Wright const CeedInt ind = node * strides_node + elem * strides_elem; 144bd882c8aSJames Wright for (CeedInt comp = 0; comp < NUM_COMP; comp++) { 145bd882c8aSJames Wright d_v[ind + comp * strides_comp] = r_v[z + comp * P_1D]; 146bd882c8aSJames Wright } 147bd882c8aSJames Wright } 148bd882c8aSJames Wright } 149bd882c8aSJames Wright } 150