xref: /libCEED/include/ceed/jit-source/hip/hip-ref-basis-nontensor-templates.h (revision 9dc0ea9a12d5a2dbb50983bee29c25b398979cc0)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 /// @file
9 /// Internal header for HIP non-tensor product basis templates
10 
11 #include <ceed.h>
12 
13 //------------------------------------------------------------------------------
14 // Tensor contraction
15 //------------------------------------------------------------------------------
template <int NUM_COMP, int Q_COMP, int P, int Q>
inline __device__ void Contract(const CeedInt elem, const CeedInt strides_elem_U, const CeedInt strides_elem_V, const CeedInt strides_comp_U,
                                const CeedInt strides_comp_V, const CeedInt strides_q_comp_V, const CeedScalar *__restrict__ d_B,
                                const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V) {
  const CeedInt thread = threadIdx.x;
  CeedScalar    accum[Q_COMP];
  // TODO load B in shared memory if blockDim.z > 1?

  // Run with Q threads: each thread owns one quadrature point and accumulates
  // all Q_COMP outputs for it, processing one field component at a time.
  for (CeedInt c = 0; c < NUM_COMP; c++) {
    const CeedScalar *u = d_U + elem * strides_elem_U + c * strides_comp_U;

    for (CeedInt q = 0; q < Q_COMP; q++) accum[q] = 0.0;
    // Dot each column of B (for every Q-component) with the P nodal values.
    for (CeedInt p = 0; p < P; p++) {
      const CeedScalar u_p = u[p];

      for (CeedInt q = 0; q < Q_COMP; q++) accum[q] += u_p * d_B[p + thread * P + q * P * Q];
    }
    // Write this thread's quadrature-point results for all Q-components.
    for (CeedInt q = 0; q < Q_COMP; q++) {
      d_V[elem * strides_elem_V + c * strides_comp_V + q * strides_q_comp_V + thread] = accum[q];
    }
  }
}
39 
40 //------------------------------------------------------------------------------
41 // Tensor contraction transpose
42 //------------------------------------------------------------------------------
template <int NUM_COMP, int Q_COMP, int P, int Q>
inline __device__ void ContractTranspose(const CeedInt elem, const CeedInt strides_elem_U, const CeedInt strides_elem_V, const CeedInt strides_comp_U,
                                         const CeedInt strides_comp_V, const CeedInt strides_q_comp_U, const CeedScalar *__restrict__ d_B,
                                         const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V) {
  const CeedInt thread = threadIdx.x;
  // TODO load B in shared memory if blockDim.z > 1?

  // Run with P threads: each thread owns one basis node and sums the
  // contributions from every quadrature point and Q-component into it.
  for (CeedInt c = 0; c < NUM_COMP; c++) {
    CeedScalar accum = 0.0;

    for (CeedInt d = 0; d < Q_COMP; d++) {
      // Quadrature data for this element, field component, and Q-component.
      const CeedScalar *u = d_U + elem * strides_elem_U + c * strides_comp_U + d * strides_q_comp_U;

      // Transposed application: read B with the roles of node and point swapped.
      for (CeedInt q = 0; q < Q; q++) accum += d_B[thread + q * P + d * P * Q] * u[q];
    }
    d_V[elem * strides_elem_V + c * strides_comp_V + thread] = accum;
  }
}
62