// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED:  http://github.com/ceed

/// @file
/// Internal header for HIP non-tensor product basis templates
#ifndef CEED_HIP_REF_BASIS_NONTENSOR_TEMPLATES_H
#define CEED_HIP_REF_BASIS_NONTENSOR_TEMPLATES_H

#include <ceed.h>
//------------------------------------------------------------------------------
// Tensor contraction
//------------------------------------------------------------------------------
template <int NUM_COMP, int Q_COMP, int P, int Q>
inline __device__ void Contract(const CeedInt elem, const CeedInt strides_elem_U, const CeedInt strides_elem_V, const CeedInt strides_comp_U,
                                const CeedInt strides_comp_V, const CeedInt strides_q_comp_V, const CeedScalar *__restrict__ d_B,
                                const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V) {
  // Apply the P x (Q * Q_COMP) basis matrix d_B for one element: each thread
  // (run with Q threads in x) owns one quadrature point and accumulates all
  // Q_COMP output values for it, one field component at a time.
  const CeedInt thread_q = threadIdx.x;
  // TODO load B in shared memory if blockDim.z > 1?

  for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
    // Base of this element/component's nodal input values
    const CeedScalar *u_comp = d_U + elem * strides_elem_U + comp * strides_comp_U;
    CeedScalar        acc[Q_COMP] = {0.0};

    for (CeedInt node = 0; node < P; node++) {
      // Load the nodal value once and reuse it across all Q_COMP outputs
      const CeedScalar u_val = u_comp[node];

      for (CeedInt d = 0; d < Q_COMP; d++) acc[d] += d_B[d * P * Q + thread_q * P + node] * u_val;
    }
    // Write the Q_COMP results for this thread's quadrature point
    CeedScalar *v_out = d_V + elem * strides_elem_V + comp * strides_comp_V + thread_q;

    for (CeedInt d = 0; d < Q_COMP; d++) v_out[d * strides_q_comp_V] = acc[d];
  }
}
//------------------------------------------------------------------------------
// Tensor contraction transpose
//------------------------------------------------------------------------------
template <int NUM_COMP, int Q_COMP, int P, int Q>
inline __device__ void ContractTranspose(const CeedInt elem, const CeedInt strides_elem_U, const CeedInt strides_elem_V, const CeedInt strides_comp_U,
                                         const CeedInt strides_comp_V, const CeedInt strides_q_comp_U, const CeedScalar *__restrict__ d_B,
                                         const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V) {
  // Apply the transpose of d_B for one element: each thread (run with P
  // threads in x) owns one node and sums contributions from every quadrature
  // point and every Q_COMP slice, one field component at a time.
  const CeedInt thread_p = threadIdx.x;
  // TODO load B in shared memory if blockDim.z > 1?

  for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
    CeedScalar acc = 0.0;

    for (CeedInt d = 0; d < Q_COMP; d++) {
      // Base of this element/component/slice's quadrature-point input values
      const CeedScalar *u_slice = d_U + elem * strides_elem_U + comp * strides_comp_U + d * strides_q_comp_U;

      for (CeedInt q = 0; q < Q; q++) acc += d_B[d * P * Q + q * P + thread_p] * u_slice[q];
    }
    // Single nodal result for this thread and component
    d_V[elem * strides_elem_V + comp * strides_comp_V + thread_p] = acc;
  }
}
//------------------------------------------------------------------------------

#endif  // CEED_HIP_REF_BASIS_NONTENSOR_TEMPLATES_H