1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors. 2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3 // 4 // SPDX-License-Identifier: BSD-2-Clause 5 // 6 // This file is part of CEED: http://github.com/ceed 7 8 /// @file 9 /// Internal header for SYCL shared memory basis read/write templates 10 11 #include <ceed.h> 12 #include "sycl-types.h" 13 14 //------------------------------------------------------------------------------ 15 // Helper function: load matrices for basis actions 16 //------------------------------------------------------------------------------ 17 inline void loadMatrix(const CeedInt N, const CeedScalar *restrict d_B, CeedScalar *restrict B) { 18 const CeedInt item_id = get_local_linear_id(); 19 const CeedInt group_size = get_local_size(0) * get_local_size(1) * get_local_size(2); 20 for (CeedInt i = item_id; i < N; i += group_size) B[i] = d_B[i]; 21 } 22 23 //------------------------------------------------------------------------------ 24 // 1D 25 //------------------------------------------------------------------------------ 26 27 //------------------------------------------------------------------------------ 28 // E-vector -> single element 29 //------------------------------------------------------------------------------ 30 inline void ReadElementStrided1d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node, 31 const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar *restrict d_u, 32 private CeedScalar *restrict r_u) { 33 const CeedInt item_id_x = get_local_id(0); 34 const CeedInt elem = get_global_id(2); 35 36 if (item_id_x < P_1D && elem < num_elem) { 37 const CeedInt node = item_id_x; 38 const CeedInt ind = node * strides_node + elem * strides_elem; 39 for (CeedInt comp = 0; comp < NUM_COMP; comp++) { 40 r_u[comp] = d_u[ind + comp * strides_comp]; 41 } 42 } 43 } 44 45 //------------------------------------------------------------------------------ 46 // Single element -> E-vector 47 //------------------------------------------------------------------------------ 48 inline void WriteElementStrided1d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node, 49 const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar *restrict r_v, 50 global CeedScalar *restrict d_v) { 51 const CeedInt item_id_x = get_local_id(0); 52 const CeedInt elem = get_global_id(2); 53 54 if (item_id_x < P_1D && elem < num_elem) { 55 const CeedInt node = item_id_x; 56 const CeedInt ind = node * strides_node + elem * strides_elem; 57 for (CeedInt comp = 0; comp < NUM_COMP; comp++) { 58 d_v[ind + comp * strides_comp] = r_v[comp]; 59 } 60 } 61 } 62 63 //------------------------------------------------------------------------------ 64 // 2D 65 //------------------------------------------------------------------------------ 66 67 //------------------------------------------------------------------------------ 68 // E-vector -> single element 69 //------------------------------------------------------------------------------ 70 inline void ReadElementStrided2d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node, 71 const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar *restrict d_u, 72 private CeedScalar *restrict r_u) { 73 const CeedInt item_id_x = get_local_id(0); 74 const CeedInt item_id_y = get_local_id(1); 75 const CeedInt elem = get_global_id(2); 76 77 if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) { 78 const CeedInt node = item_id_x + item_id_y * P_1D; 79 const CeedInt ind = node * strides_node + elem * strides_elem; 80 for (CeedInt comp = 0; comp < NUM_COMP; comp++) { 81 r_u[comp] = d_u[ind + comp * strides_comp]; 82 } 83 } 84 } 85 86 //------------------------------------------------------------------------------ 87 // Single element -> E-vector 88 //------------------------------------------------------------------------------ 89 inline void WriteElementStrided2d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node, 90 const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar *restrict r_v, 91 global CeedScalar *restrict d_v) { 92 const CeedInt item_id_x = get_local_id(0); 93 const CeedInt item_id_y = get_local_id(1); 94 const CeedInt elem = get_global_id(2); 95 96 if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) { 97 const CeedInt node = item_id_x + item_id_y * P_1D; 98 const CeedInt ind = node * strides_node + elem * strides_elem; 99 for (CeedInt comp = 0; comp < NUM_COMP; comp++) { 100 d_v[ind + comp * strides_comp] = r_v[comp]; 101 } 102 } 103 } 104 105 //------------------------------------------------------------------------------ 106 // 3D 107 //------------------------------------------------------------------------------ 108 109 //------------------------------------------------------------------------------ 110 // E-vector -> single element 111 //------------------------------------------------------------------------------ 112 inline void ReadElementStrided3d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node, 113 const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar *restrict d_u, 114 private CeedScalar *restrict r_u) { 115 const CeedInt item_id_x = get_local_id(0); 116 const CeedInt item_id_y = get_local_id(1); 117 const CeedInt elem = get_global_id(2); 118 119 if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) { 120 for (CeedInt z = 0; z < P_1D; z++) { 121 const CeedInt node = item_id_x + item_id_y * P_1D + z * P_1D * P_1D; 122 const CeedInt ind = node * strides_node + elem * strides_elem; 123 for (CeedInt comp = 0; comp < NUM_COMP; comp++) { 124 r_u[z + comp * P_1D] = d_u[ind + comp * strides_comp]; 125 } 126 } 127 } 128 } 129 130 //------------------------------------------------------------------------------ 131 // Single element -> E-vector 132 //------------------------------------------------------------------------------ 133 inline void WriteElementStrided3d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node, 134 const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar *restrict r_v, 135 global CeedScalar *restrict d_v) { 136 const CeedInt item_id_x = get_local_id(0); 137 const CeedInt item_id_y = get_local_id(1); 138 const CeedInt elem = get_global_id(2); 139 140 if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) { 141 for (CeedInt z = 0; z < P_1D; z++) { 142 const CeedInt node = item_id_x + item_id_y * P_1D + z * P_1D * P_1D; 143 const CeedInt ind = node * strides_node + elem * strides_elem; 144 for (CeedInt comp = 0; comp < NUM_COMP; comp++) { 145 d_v[ind + comp * strides_comp] = r_v[z + comp * P_1D]; 146 } 147 } 148 } 149 } 150