xref: /libCEED/include/ceed/jit-source/cuda/cuda-ref-basis-nontensor-templates.h (revision db2becc9f302fe8eb3a32ace50ce3f3a5d42e6c4)
15aed82e4SJeremy L Thompson // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2d075f50bSSebastian Grimberg // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3d075f50bSSebastian Grimberg //
4d075f50bSSebastian Grimberg // SPDX-License-Identifier: BSD-2-Clause
5d075f50bSSebastian Grimberg //
6d075f50bSSebastian Grimberg // This file is part of CEED:  http://github.com/ceed
7d075f50bSSebastian Grimberg 
8d075f50bSSebastian Grimberg /// @file
9d075f50bSSebastian Grimberg /// Internal header for CUDA non-tensor product basis templates
10d075f50bSSebastian Grimberg 
11d075f50bSSebastian Grimberg #include <ceed.h>
12d075f50bSSebastian Grimberg 
13d075f50bSSebastian Grimberg //------------------------------------------------------------------------------
14d075f50bSSebastian Grimberg // Tensor contraction
15d075f50bSSebastian Grimberg //------------------------------------------------------------------------------
template <int NUM_COMP, int Q_COMP, int P, int Q>
inline __device__ void Contract(const CeedInt elem, const CeedInt strides_elem_U, const CeedInt strides_elem_V, const CeedInt strides_comp_U,
                                const CeedInt strides_comp_V, const CeedInt strides_q_comp_V, const CeedScalar *__restrict__ d_B,
                                const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V) {
  // Apply B to every component of element `elem`. With q = threadIdx.x this computes
  //   V[comp][d][q] = sum_{p < P} B(p, q, d) * U[comp][p],   for d in [0, Q_COMP),
  // where entry (p, q, d) of B is stored at d_B[p + q * P + d * P * Q].
  //
  // Strides (in scalars):
  //   U[comp][p]    -> d_U[elem * strides_elem_U + comp * strides_comp_U + p]
  //   V[comp][d][q] -> d_V[elem * strides_elem_V + comp * strides_comp_V + d * strides_q_comp_V + q]
  //
  // One thread per output index q: expects blockDim.x == Q. threadIdx.x is not bounds-checked here.
  // TODO load B in shared memory if blockDim.z > 1?
  const CeedInt     q   = threadIdx.x;
  const CeedScalar *B_q = d_B + q * P;  // B entries for output index q; sub-block d is offset by d * P * Q

  for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
    const CeedScalar *u_comp = d_U + elem * strides_elem_U + comp * strides_comp_U;
    CeedScalar       *v_comp = d_V + elem * strides_elem_V + comp * strides_comp_V + q;
    CeedScalar        acc[Q_COMP];

    for (CeedInt d = 0; d < Q_COMP; d++) acc[d] = 0.0;

    // Loop over inputs outermost so each U entry is loaded once and reused for all Q_COMP outputs
    for (CeedInt p = 0; p < P; p++) {
      const CeedScalar u_p = u_comp[p];

      for (CeedInt d = 0; d < Q_COMP; d++) acc[d] += B_q[p + d * P * Q] * u_p;
    }

    // Outputs overwrite d_V (no accumulation into existing values)
    for (CeedInt d = 0; d < Q_COMP; d++) v_comp[d * strides_q_comp_V] = acc[d];
  }
}
39d075f50bSSebastian Grimberg 
40d075f50bSSebastian Grimberg //------------------------------------------------------------------------------
41d075f50bSSebastian Grimberg // Tensor contraction transpose
42d075f50bSSebastian Grimberg //------------------------------------------------------------------------------
template <int NUM_COMP, int Q_COMP, int P, int Q>
inline __device__ void ContractTranspose(const CeedInt elem, const CeedInt strides_elem_U, const CeedInt strides_elem_V, const CeedInt strides_comp_U,
                                         const CeedInt strides_comp_V, const CeedInt strides_q_comp_U, const CeedScalar *__restrict__ d_B,
                                         const CeedScalar *__restrict__ d_U, CeedScalar *__restrict__ d_V) {
  // Apply the transpose of B to every component of element `elem`. With p = threadIdx.x this computes
  //   V[comp][p] += sum_{d < Q_COMP} sum_{q < Q} B(p, q, d) * U[comp][d][q],
  // where entry (p, q, d) of B is stored at d_B[p + q * P + d * P * Q].
  //
  // Strides (in scalars):
  //   U[comp][d][q] -> d_U[elem * strides_elem_U + comp * strides_comp_U + d * strides_q_comp_U + q]
  //   V[comp][p]    -> d_V[elem * strides_elem_V + comp * strides_comp_V + p]
  //
  // The result is ADDED to the existing contents of d_V rather than overwriting them.
  // NOTE(review): callers presumably zero d_V first unless summing into it is intended — confirm.
  //
  // One thread per output index p: expects blockDim.x == P. threadIdx.x is not bounds-checked here.
  // TODO load B in shared memory if blockDim.z > 1?
  const CeedInt     p   = threadIdx.x;
  const CeedScalar *B_p = d_B + p;  // B entries for output index p; consecutive q are P apart

  for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
    const CeedScalar *u_comp = d_U + elem * strides_elem_U + comp * strides_comp_U;
    CeedScalar        acc    = 0.0;

    // Reduce over all Q_COMP sub-blocks into a single accumulator (d outer, q inner)
    for (CeedInt d = 0; d < Q_COMP; d++) {
      const CeedScalar *u_d = u_comp + d * strides_q_comp_U;
      const CeedScalar *B_d = B_p + d * P * Q;

      for (CeedInt q = 0; q < Q; q++) acc += B_d[q * P] * u_d[q];
    }
    d_V[elem * strides_elem_V + comp * strides_comp_V + p] += acc;
  }
}
62