// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed

/// @file
/// Internal header for HIP non-tensor product basis templates
#ifndef CEED_HIP_REF_BASIS_NONTENSOR_TEMPLATES_H
#define CEED_HIP_REF_BASIS_NONTENSOR_TEMPLATES_H

#include <ceed.h>

//------------------------------------------------------------------------------
// Tensor contraction
//------------------------------------------------------------------------------
// Apply the basis matrix B to one element's nodal values, producing values at
// the Q quadrature points.
//
// For each component comp in [0, NUM_COMP) and each basis slice d in
// [0, Q_COMP), the thread with threadIdx.x == q computes
//   V[elem, comp, d, q] = sum_{i < P} B[d * P * Q + q * P + i] * U[elem, comp, i]
// i.e. d_B is read as Q_COMP consecutive Q x P row-major matrices
// (row = quadrature point, column = node).
//
// Launch layout: one thread per quadrature point along threadIdx.x (blockDim.x
// is expected to be Q). Threads with threadIdx.x >= Q return immediately, so a
// larger block cannot read past d_B or write outside this element's output.
//
// Parameters:
//   elem             - element whose values are contracted by this call
//   strides_elem_U   - offset between consecutive elements in d_U
//   strides_elem_V   - offset between consecutive elements in d_V
//   strides_comp_U   - offset between consecutive components in d_U; within a
//                      component the P nodal values are contiguous
//   strides_comp_V   - offset between consecutive components in d_V
//   strides_q_comp_V - offset between consecutive basis slices d in d_V; within
//                      a slice the Q quadrature values are contiguous
//   d_B              - basis matrix, read directly (not staged in shared memory)
//   d_U              - input nodal values
//   d_V              - output quadrature values (overwritten, not accumulated)
//------------------------------------------------------------------------------
template <int NUM_COMP, int Q_COMP, int P, int Q>
inline __device__ void Contract(const CeedInt elem, const CeedInt strides_elem_U, const CeedInt strides_elem_V, const CeedInt strides_comp_U,
                                const CeedInt strides_comp_V, const CeedInt strides_q_comp_V, const CeedScalar *__restrict__ d_B,
                                const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V) {
  const CeedInt t_id = threadIdx.x;

  // Only the first Q threads map to quadrature points. This helper contains no
  // barrier, so returning early cannot desynchronize the block.
  if (t_id >= Q) return;

  // Per-element base pointers, hoisted out of the component loop.
  // NOTE(review): offsets are formed in CeedInt; if CeedInt is 32-bit, very large
  // elem * stride products could overflow - confirm against expected problem sizes.
  const CeedScalar *U_elem = d_U + elem * strides_elem_U;
  CeedScalar       *V_elem = d_V + elem * strides_elem_V;
  // Row t_id of the first Q x P slice of B; slice d starts P * Q entries later.
  const CeedScalar *B_row = d_B + t_id * P;
  // TODO load B in shared memory if blockDim.z > 1?

  for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
    const CeedScalar *U = U_elem + comp * strides_comp_U;
    CeedScalar        r_V[Q_COMP];

    for (CeedInt d = 0; d < Q_COMP; d++) r_V[d] = 0.0;
    for (CeedInt i = 0; i < P; i++) {
      // Every thread reads the same U[i]; load it once and reuse it for all slices.
      const CeedScalar val = U[i];

      for (CeedInt d = 0; d < Q_COMP; d++) r_V[d] += B_row[i + d * P * Q] * val;
    }
    for (CeedInt d = 0; d < Q_COMP; d++) {
      V_elem[comp * strides_comp_V + d * strides_q_comp_V + t_id] = r_V[d];
    }
  }
}

//------------------------------------------------------------------------------
// Tensor contraction transpose
//------------------------------------------------------------------------------
// Apply B^T to one element's quadrature values, producing values at the P nodes.
//
// For each component comp in [0, NUM_COMP), the thread with threadIdx.x == p
// computes
//   V[elem, comp, p] = sum_{d < Q_COMP} sum_{i < Q} B[d * P * Q + i * P + p] * U[elem, comp, d, i]
// using the same Q_COMP x (Q x P row-major) layout of d_B as Contract.
//
// Launch layout: one thread per node along threadIdx.x (blockDim.x is expected
// to be P). Threads with threadIdx.x >= P return immediately, so a larger block
// cannot read past d_B or write outside this element's output.
//
// Parameters:
//   elem             - element whose values are contracted by this call
//   strides_elem_U   - offset between consecutive elements in d_U
//   strides_elem_V   - offset between consecutive elements in d_V
//   strides_comp_U   - offset between consecutive components in d_U
//   strides_comp_V   - offset between consecutive components in d_V; within a
//                      component the P nodal values are contiguous
//   strides_q_comp_U - offset between consecutive basis slices d in d_U; within
//                      a slice the Q quadrature values are contiguous
//   d_B              - basis matrix, read directly (not staged in shared memory)
//   d_U              - input quadrature values
//   d_V              - output nodal values (overwritten, not accumulated)
//------------------------------------------------------------------------------
template <int NUM_COMP, int Q_COMP, int P, int Q>
inline __device__ void ContractTranspose(const CeedInt elem, const CeedInt strides_elem_U, const CeedInt strides_elem_V, const CeedInt strides_comp_U,
                                         const CeedInt strides_comp_V, const CeedInt strides_q_comp_U, const CeedScalar *__restrict__ d_B,
                                         const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V) {
  const CeedInt t_id = threadIdx.x;

  // Only the first P threads map to nodes. This helper contains no barrier, so
  // returning early cannot desynchronize the block.
  if (t_id >= P) return;

  // Per-element base pointers, hoisted out of the component loop.
  // NOTE(review): offsets are formed in CeedInt; if CeedInt is 32-bit, very large
  // elem * stride products could overflow - confirm against expected problem sizes.
  const CeedScalar *U_elem = d_U + elem * strides_elem_U;
  CeedScalar       *V_elem = d_V + elem * strides_elem_V;
  // Column t_id of the first Q x P slice of B; adjacent threads read adjacent entries.
  const CeedScalar *B_col = d_B + t_id;
  // TODO load B in shared memory if blockDim.z > 1?

  for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
    CeedScalar r_V = 0.0;

    for (CeedInt d = 0; d < Q_COMP; d++) {
      const CeedScalar *U = U_elem + comp * strides_comp_U + d * strides_q_comp_U;

      for (CeedInt i = 0; i < Q; i++) r_V += B_col[i * P + d * P * Q] * U[i];
    }
    V_elem[comp * strides_comp_V + t_id] = r_V;
  }
}

//------------------------------------------------------------------------------

#endif  // CEED_HIP_REF_BASIS_NONTENSOR_TEMPLATES_H