// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed

/// @file
/// Internal header for SYCL shared memory basis read/write templates
#ifndef CEED_SYCL_SHARED_BASIS_READ_WRITE_TEMPLATES_H
#define CEED_SYCL_SHARED_BASIS_READ_WRITE_TEMPLATES_H

#include <ceed.h>
#include "sycl-types.h"

// All helpers in this header are `static inline` rather than plain `inline`.
// This source is compiled as OpenCL C, which follows C99 inline semantics: a plain
// `inline` definition is not an external definition. Any call the compiler decides
// not to inline (e.g. with optimizations disabled) would then reference a symbol
// that is never emitted. `static inline` gives every call a definition to bind to.

//------------------------------------------------------------------------------
// Helper function: load matrices for basis actions
//
// Cooperatively copies the N entries of d_B into B using every work-item of the
// work-group. The work-item with linear local id k copies entries
// k, k + group_size, k + 2 * group_size, ...
// This function issues no barrier.
// NOTE(review): callers presumably synchronize the work-group before reading B
// (B is presumably work-group local memory) -- confirm at call sites.
//------------------------------------------------------------------------------
static inline void loadMatrix(const CeedInt N, const CeedScalar *restrict d_B, CeedScalar *restrict B) {
  const CeedInt item_id    = get_local_linear_id();
  const CeedInt group_size = get_local_size(0) * get_local_size(1) * get_local_size(2);

  for (CeedInt i = item_id; i < N; i += group_size) B[i] = d_B[i];
}

//------------------------------------------------------------------------------
// 1D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// E-vector -> single element
//
// Element index is get_global_id(2); node index is the local x id.
// Component `comp` of that node is read from
//   d_u[node * strides_node + comp * strides_comp + elem * strides_elem]
// into r_u[comp], so r_u must hold at least NUM_COMP values.
// Work-items with local x id >= P_1D, or with elem >= num_elem, leave r_u untouched.
//------------------------------------------------------------------------------
static inline void ReadElementStrided1d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
                                        const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar *restrict d_u,
                                        private CeedScalar *restrict r_u) {
  const CeedInt item_id_x = get_local_id(0);
  const CeedInt elem      = get_global_id(2);

  if (item_id_x < P_1D && elem < num_elem) {
    const CeedInt node = item_id_x;
    const CeedInt ind  = node * strides_node + elem * strides_elem;

    for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
      r_u[comp] = d_u[ind + comp * strides_comp];
    }
  }
}

//------------------------------------------------------------------------------
// Single element -> E-vector
//
// Inverse of ReadElementStrided1d: writes r_v[comp] to
//   d_v[node * strides_node + comp * strides_comp + elem * strides_elem]
// with node = local x id and elem = get_global_id(2).
// Work-items with local x id >= P_1D, or with elem >= num_elem, write nothing.
//------------------------------------------------------------------------------
static inline void WriteElementStrided1d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
                                         const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar *restrict r_v,
                                         global CeedScalar *restrict d_v) {
  const CeedInt item_id_x = get_local_id(0);
  const CeedInt elem      = get_global_id(2);

  if (item_id_x < P_1D && elem < num_elem) {
    const CeedInt node = item_id_x;
    const CeedInt ind  = node * strides_node + elem * strides_elem;

    for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
      d_v[ind + comp * strides_comp] = r_v[comp];
    }
  }
}

//------------------------------------------------------------------------------
// 2D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// E-vector -> single element
//
// Work-item (x, y) = (get_local_id(0), get_local_id(1)) of element
// elem = get_global_id(2) handles node x + y * P_1D. It reads component `comp` from
//   d_u[node * strides_node + comp * strides_comp + elem * strides_elem]
// into r_u[comp], so r_u must hold at least NUM_COMP values.
// Work-items with x >= P_1D, y >= P_1D or elem >= num_elem leave r_u untouched.
//------------------------------------------------------------------------------
static inline void ReadElementStrided2d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
                                        const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar *restrict d_u,
                                        private CeedScalar *restrict r_u) {
  const CeedInt item_id_x = get_local_id(0);
  const CeedInt item_id_y = get_local_id(1);
  const CeedInt elem      = get_global_id(2);

  if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
    const CeedInt node = item_id_x + item_id_y * P_1D;
    const CeedInt ind  = node * strides_node + elem * strides_elem;

    for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
      r_u[comp] = d_u[ind + comp * strides_comp];
    }
  }
}

//------------------------------------------------------------------------------
// Single element -> E-vector
//
// Inverse of ReadElementStrided2d: work-item (x, y) writes r_v[comp] for node
// x + y * P_1D of element get_global_id(2) to
//   d_v[node * strides_node + comp * strides_comp + elem * strides_elem].
// Work-items with x >= P_1D, y >= P_1D or elem >= num_elem write nothing.
//------------------------------------------------------------------------------
static inline void WriteElementStrided2d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
                                         const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar *restrict r_v,
                                         global CeedScalar *restrict d_v) {
  const CeedInt item_id_x = get_local_id(0);
  const CeedInt item_id_y = get_local_id(1);
  const CeedInt elem      = get_global_id(2);

  if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
    const CeedInt node = item_id_x + item_id_y * P_1D;
    const CeedInt ind  = node * strides_node + elem * strides_elem;

    for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
      d_v[ind + comp * strides_comp] = r_v[comp];
    }
  }
}

//------------------------------------------------------------------------------
// 3D
//------------------------------------------------------------------------------

//------------------------------------------------------------------------------
// E-vector -> single element
//
// Work-item (x, y) of element elem = get_global_id(2) loops over every z in
// [0, P_1D) and reads node x + y * P_1D + z * P_1D * P_1D.
// Component `comp` of that node is read from
//   d_u[node * strides_node + comp * strides_comp + elem * strides_elem]
// into r_u[z + comp * P_1D], so r_u must hold at least NUM_COMP * P_1D values
// (one z-column per component).
// Work-items with x >= P_1D, y >= P_1D or elem >= num_elem leave r_u untouched.
//------------------------------------------------------------------------------
static inline void ReadElementStrided3d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
                                        const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar *restrict d_u,
                                        private CeedScalar *restrict r_u) {
  const CeedInt item_id_x = get_local_id(0);
  const CeedInt item_id_y = get_local_id(1);
  const CeedInt elem      = get_global_id(2);

  if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
    for (CeedInt z = 0; z < P_1D; z++) {
      const CeedInt node = item_id_x + item_id_y * P_1D + z * P_1D * P_1D;
      const CeedInt ind  = node * strides_node + elem * strides_elem;

      for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
        r_u[z + comp * P_1D] = d_u[ind + comp * strides_comp];
      }
    }
  }
}

//------------------------------------------------------------------------------
// Single element -> E-vector
//
// Inverse of ReadElementStrided3d: work-item (x, y) loops over z in [0, P_1D) and
// writes r_v[z + comp * P_1D] for node x + y * P_1D + z * P_1D * P_1D of element
// get_global_id(2) to
//   d_v[node * strides_node + comp * strides_comp + elem * strides_elem].
// Work-items with x >= P_1D, y >= P_1D or elem >= num_elem write nothing.
//------------------------------------------------------------------------------
static inline void WriteElementStrided3d(const CeedInt NUM_COMP, const CeedInt P_1D, const CeedInt num_elem, const CeedInt strides_node,
                                         const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar *restrict r_v,
                                         global CeedScalar *restrict d_v) {
  const CeedInt item_id_x = get_local_id(0);
  const CeedInt item_id_y = get_local_id(1);
  const CeedInt elem      = get_global_id(2);

  if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
    for (CeedInt z = 0; z < P_1D; z++) {
      const CeedInt node = item_id_x + item_id_y * P_1D + z * P_1D * P_1D;
      const CeedInt ind  = node * strides_node + elem * strides_elem;

      for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
        d_v[ind + comp * strides_comp] = r_v[z + comp * P_1D];
      }
    }
  }
}

//------------------------------------------------------------------------------

#endif  // CEED_SYCL_SHARED_BASIS_READ_WRITE_TEMPLATES_H