xref: /libCEED/rust/libceed-sys/c-src/include/ceed/jit-source/magma/magma-basis-interp-1d.h (revision f80f4a748154eed4bc661c135f695b92b1bc45b9)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
// Accessor for the basis matrix stored in sT: entry (i, j) is found at sT[j * P_ + i]
#define sT(i, j) sT[(j)*P_ + (i)]

//////////////////////////////////////////////////////////////////////////////////////////
// interp basis action (1D)
//
// For every component c, thread tx produces one output entry:
//   sV[c][tx] = seed + sum_{k = 0}^{P_ - 1} sU[c][k] * sT(k, tx)
// where seed is the current sV[c][tx] when transT == MagmaTrans and zero otherwise.
//
// Template parameters:
//   T      scalar type
//   DIM_   not referenced in this routine
//   NCOMP_ number of components (length of the sU / sV pointer arrays)
//   P_     length of each input row sU[c]
//   Q_     length of each output row sV[c]
//
// Assumptions carried by the caller:
//   1. 1D threads of size max(P_,Q_)
//   2. sU[c] is 1xP_ and sV[c] is 1xQ_, both in shared memory
//   3. Each thread with tx < Q_ computes exactly one entry of every sV[c]
//   4. The caller synchronizes before and after the call (there is no barrier inside)
//   5. The layout for U and V is different from the 2D/3D problem
template <typename T, int DIM_, int NCOMP_, int P_, int Q_>
static __device__ __inline__ void magma_interp_1d_device(const T *sT, magma_trans_t transT, T *sU[NCOMP_], T *sV[NCOMP_], const int tx) {
  // Threads past the end of the output row have nothing to compute; control returns to the caller
  if (tx >= Q_) return;

  for (int comp = 0; comp < NCOMP_; comp++) {
    const T *u_row = sU[comp];
    T       *v_row = sV[comp];

    // Transposed application accumulates onto the existing output, otherwise start from zero
    T acc;
    if (transT == MagmaTrans) {
      acc = v_row[tx];
    } else {
      acc = 0.0;
    }

    // Dot product of the input row with column tx of the matrix
    for (int k = 0; k < P_; k++) {
      acc += u_row[k] * sT(k, tx);  // sT[tx * P_ + k];
    }
    v_row[tx] = acc;
  }
}
35 
36 //////////////////////////////////////////////////////////////////////////////////////////
// Non-transposed 1D interpolation: each threadIdx.y row of the block handles one element and
// overwrites its V with the product of its U (NCOMP rows of length P) and T (result rows of length Q).
//
// Parameters:
//   dT             basis matrix in global memory, staged into shared memory by the ty == 0 row
//   dU             input values; element e starts at dU + e * estrdU, cstrdU is forwarded to read_1d
//   dV             output values; element e starts at dV + e * estrdV, cstrdV is forwarded to write_1d
//   nelem          number of elements; rows with elem_id >= nelem do no work
//
// Shared memory layout (in CeedScalar entries), as established by the pointer setup below:
//   [0, P*Q)                       sT
//   then, per ty: NCOMP*(P+Q)      NCOMP rows of U (P each) followed by NCOMP rows of V (Q each)
// NOTE(review): the launch is presumed to provide at least P*Q + blockDim.y*NCOMP*(P+Q) entries - confirm against the host code.
extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(MAXPQ, MAGMA_MAXTHREADS_1D)) __global__
    void magma_interpn_1d_kernel(const CeedScalar *dT, const CeedScalar *dU, const int estrdU, const int cstrdU, CeedScalar *dV, const int estrdV,
                                 const int cstrdV, const int nelem) {
  MAGMA_DEVICE_SHARED(CeedScalar, shared_data)

  const int     tx      = threadIdx.x;
  const int     ty      = threadIdx.y;
  const int     elem_id = (blockIdx.x * blockDim.y) + ty;
  magma_trans_t transT  = MagmaNoTrans;

  // Rows past the last element must not return early: __syncthreads() has to be reached by every
  // thread of the block. They are kept alive and simply skip all memory traffic and arithmetic.
  const bool active = elem_id < nelem;

  CeedScalar *sU[NCOMP];
  CeedScalar *sV[NCOMP];

  // assign shared memory pointers
  CeedScalar *sT = (CeedScalar *)(shared_data);
  CeedScalar *sW = sT + P * Q;
  sU[0]          = sW + ty * NCOMP * (P + Q);
  sV[0]          = sU[0] + (NCOMP * 1 * P);
  for (int icomp = 1; icomp < NCOMP; icomp++) {
    sU[icomp] = sU[icomp - 1] + (1 * P);
    sV[icomp] = sV[icomp - 1] + (1 * Q);
  }

  // read T (only the first row of the block loads it; if that row is inactive, the whole block is)
  if (active && ty == 0) {
    dread_T_gm2sm<P, Q>(tx, transT, dT, sT);
  }

  if (active) {
    // shift global memory pointers by elem stride
    dU += elem_id * estrdU;
    dV += elem_id * estrdV;

    // read U
    read_1d<CeedScalar, P, NCOMP>(dU, cstrdU, sU, tx);
  }

  // sT and sU must be fully written before any thread consumes them
  __syncthreads();
  if (active) {
    magma_interp_1d_device<CeedScalar, DIM, NCOMP, P, Q>(sT, transT, sU, sV, tx);
  }
  __syncthreads();

  // write V
  if (active) {
    write_1d<CeedScalar, Q, NCOMP>(sV, dV, cstrdV, tx);
  }
}
81 
82 //////////////////////////////////////////////////////////////////////////////////////////
// Transposed 1D interpolation: each threadIdx.y row of the block handles one element, reads its U
// (NCOMP rows of length Q) and its current V (NCOMP rows of length P), and writes back V with the
// transposed basis action added onto it.
//
// Parameters:
//   dT             basis matrix in global memory, staged into shared memory by the ty == 0 row
//   dU             input values; element e starts at dU + e * estrdU, cstrdU is forwarded to read_1d
//   dV             values that are read, accumulated into and written back; element e starts at dV + e * estrdV,
//                  cstrdV is forwarded to read_1d / write_1d
//   nelem          number of elements; rows with elem_id >= nelem do no work
//
// Shared memory layout (in CeedScalar entries), as established by the pointer setup below:
//   [0, Q*P)                       sT
//   then, per ty: NCOMP*(Q+P)      NCOMP rows of U (Q each) followed by NCOMP rows of V (P each)
// NOTE(review): the launch is presumed to provide at least Q*P + blockDim.y*NCOMP*(Q+P) entries - confirm against the host code.
extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(MAXPQ, MAGMA_MAXTHREADS_1D)) __global__
    void magma_interpt_1d_kernel(const CeedScalar *dT, const CeedScalar *dU, const int estrdU, const int cstrdU, CeedScalar *dV, const int estrdV,
                                 const int cstrdV, const int nelem) {
  MAGMA_DEVICE_SHARED(CeedScalar, shared_data)

  const int     tx      = threadIdx.x;
  const int     ty      = threadIdx.y;
  const int     elem_id = (blockIdx.x * blockDim.y) + ty;
  magma_trans_t transT  = MagmaTrans;

  // Rows past the last element must not return early: __syncthreads() has to be reached by every
  // thread of the block. They are kept alive and simply skip all memory traffic and arithmetic.
  const bool active = elem_id < nelem;

  CeedScalar *sU[NCOMP];
  CeedScalar *sV[NCOMP];

  // assign shared memory pointers
  CeedScalar *sT = (CeedScalar *)(shared_data);
  CeedScalar *sW = sT + Q * P;
  sU[0]          = sW + ty * NCOMP * (Q + P);
  sV[0]          = sU[0] + (NCOMP * 1 * Q);
  for (int icomp = 1; icomp < NCOMP; icomp++) {
    sU[icomp] = sU[icomp - 1] + (1 * Q);
    sV[icomp] = sV[icomp - 1] + (1 * P);
  }

  // read T (only the first row of the block loads it; if that row is inactive, the whole block is)
  if (active && ty == 0) {
    dread_T_gm2sm<Q, P>(tx, transT, dT, sT);
  }

  if (active) {
    // shift global memory pointers by elem stride
    dU += elem_id * estrdU;
    dV += elem_id * estrdV;

    // read U
    read_1d<CeedScalar, Q, NCOMP>(dU, cstrdU, sU, tx);

    // read V (the transposed action accumulates onto the values already stored there)
    read_1d<CeedScalar, P, NCOMP>(dV, cstrdV, sV, tx);
  }

  // sT, sU and sV must be fully written before any thread consumes them
  __syncthreads();
  if (active) {
    magma_interp_1d_device<CeedScalar, DIM, NCOMP, Q, P>(sT, transT, sU, sV, tx);
  }
  __syncthreads();

  // write V
  if (active) {
    write_1d<CeedScalar, P, NCOMP>(sV, dV, cstrdV, tx);
  }
}
130