15aed82e4SJeremy L Thompson // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors. 2d075f50bSSebastian Grimberg // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3d075f50bSSebastian Grimberg // 4d075f50bSSebastian Grimberg // SPDX-License-Identifier: BSD-2-Clause 5d075f50bSSebastian Grimberg // 6d075f50bSSebastian Grimberg // This file is part of CEED: http://github.com/ceed 7d075f50bSSebastian Grimberg 8d075f50bSSebastian Grimberg /// @file 9d075f50bSSebastian Grimberg /// Internal header for HIP non-tensor product basis templates 10d075f50bSSebastian Grimberg 11d075f50bSSebastian Grimberg #include <ceed.h> 12d075f50bSSebastian Grimberg 13d075f50bSSebastian Grimberg //------------------------------------------------------------------------------ 14d075f50bSSebastian Grimberg // Tensor contraction 15d075f50bSSebastian Grimberg //------------------------------------------------------------------------------ 16d075f50bSSebastian Grimberg template <int NUM_COMP, int Q_COMP, int P, int Q> 17d075f50bSSebastian Grimberg inline __device__ void Contract(const CeedInt elem, const CeedInt strides_elem_U, const CeedInt strides_elem_V, const CeedInt strides_comp_U, 18d075f50bSSebastian Grimberg const CeedInt strides_comp_V, const CeedInt strides_q_comp_V, const CeedScalar *__restrict__ d_B, 19d075f50bSSebastian Grimberg const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V) { 20d075f50bSSebastian Grimberg const CeedInt t_id = threadIdx.x; 21d075f50bSSebastian Grimberg const CeedScalar *U; 22d075f50bSSebastian Grimberg CeedScalar r_V[Q_COMP]; 23d075f50bSSebastian Grimberg // TODO load B in shared memory if blockDim.z > 1? 24d075f50bSSebastian Grimberg 25d075f50bSSebastian Grimberg for (CeedInt comp = 0; comp < NUM_COMP; comp++) { 26d075f50bSSebastian Grimberg // Run with Q threads 27*db2becc9SJeremy L Thompson U = &d_U[elem * strides_elem_U + comp * strides_comp_U]; 28d075f50bSSebastian Grimberg for (CeedInt d = 0; d < Q_COMP; d++) r_V[d] = 0.0; 29d075f50bSSebastian Grimberg for (CeedInt i = 0; i < P; i++) { 30d075f50bSSebastian Grimberg const CeedScalar val = U[i]; 31d075f50bSSebastian Grimberg 32d075f50bSSebastian Grimberg for (CeedInt d = 0; d < Q_COMP; d++) r_V[d] += d_B[i + t_id * P + d * P * Q] * val; 33d075f50bSSebastian Grimberg } 34d075f50bSSebastian Grimberg for (CeedInt d = 0; d < Q_COMP; d++) { 35d075f50bSSebastian Grimberg d_V[elem * strides_elem_V + comp * strides_comp_V + d * strides_q_comp_V + t_id] = r_V[d]; 36d075f50bSSebastian Grimberg } 37d075f50bSSebastian Grimberg } 38d075f50bSSebastian Grimberg } 39d075f50bSSebastian Grimberg 40d075f50bSSebastian Grimberg //------------------------------------------------------------------------------ 41d075f50bSSebastian Grimberg // Tensor contraction transpose 42d075f50bSSebastian Grimberg //------------------------------------------------------------------------------ 43d075f50bSSebastian Grimberg template <int NUM_COMP, int Q_COMP, int P, int Q> 44d075f50bSSebastian Grimberg inline __device__ void ContractTranspose(const CeedInt elem, const CeedInt strides_elem_U, const CeedInt strides_elem_V, const CeedInt strides_comp_U, 45d075f50bSSebastian Grimberg const CeedInt strides_comp_V, const CeedInt strides_q_comp_U, const CeedScalar *__restrict__ d_B, 46d075f50bSSebastian Grimberg const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V) { 47d075f50bSSebastian Grimberg const CeedInt t_id = threadIdx.x; 48d075f50bSSebastian Grimberg const CeedScalar *U; 49d075f50bSSebastian Grimberg CeedScalar r_V; 50d075f50bSSebastian Grimberg // TODO load B in shared memory if blockDim.z > 1? 51d075f50bSSebastian Grimberg 52d075f50bSSebastian Grimberg for (CeedInt comp = 0; comp < NUM_COMP; comp++) { 53d075f50bSSebastian Grimberg // Run with P threads 54d075f50bSSebastian Grimberg r_V = 0.0; 55d075f50bSSebastian Grimberg for (CeedInt d = 0; d < Q_COMP; d++) { 56*db2becc9SJeremy L Thompson U = &d_U[elem * strides_elem_U + comp * strides_comp_U + d * strides_q_comp_U]; 57d075f50bSSebastian Grimberg for (CeedInt i = 0; i < Q; i++) r_V += d_B[t_id + i * P + d * P * Q] * U[i]; 58d075f50bSSebastian Grimberg } 59*db2becc9SJeremy L Thompson d_V[elem * strides_elem_V + comp * strides_comp_V + t_id] += r_V; 60d075f50bSSebastian Grimberg } 61d075f50bSSebastian Grimberg } 62