xref: /libCEED/rust/libceed-sys/c-src/include/ceed/jit-source/magma/magma-basis-grad-3d.h (revision f80f4a748154eed4bc661c135f695b92b1bc45b9)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
// Column-major indexing helpers for 2D views of linear buffers.
// Each macro expands to a subscript on an array/pointer variable that has the
// SAME name as the macro, so a variable called sT / sTmp / sTmp2 must be in
// scope wherever the corresponding macro is used.
//   sT(row, col)          -- leading dimension is P_, which must be in scope at the point of use
//   sTmp(row, col, ld)    -- leading dimension is given explicitly
//   sTmp2(row, col, ld)   -- leading dimension is given explicitly
#define sT(row, col) sT[(col)*P_ + (row)]
#define sTmp(row, col, ld) sTmp[(col) * (ld) + (row)]
#define sTmp2(row, col, ld) sTmp2[(col) * (ld) + (row)]
12 
//////////////////////////////////////////////////////////////////////////////////////////
// grad basis action (3D)
// Contracts rU with sTgrad along dimension iDIM_ and with sTinterp along the other two
// dimensions, one component at a time, and combines the result with rV.
// This function is called three times at a higher level for 3D
// T        -- scalar type
// DIM_U    -- first extent of rU[DIM_U][NCOMP_][rUsize]
// DIM_V    -- first extent of rV[DIM_V][NCOMP_][rVsize]
// NCOMP_   -- number of components (outer loop of this routine)
// P_, Q_   -- every product below multiplies by a (P_ x Q_) matrix
// iDIM_    -- the index of the outermost loop over dimensions in grad
// iDIM_U   -- which dim index of rU is accessed (always 0 for notrans, 0, 1, or 2 for trans)
// iDIM_V   -- which dim index of rV is accessed (0, 1, or 2 for notrans, always 0 for trans)
// sTinterp -- (P_ x Q_) matrix, column-major with leading dimension P_
// sTgrad   -- (P_ x Q_) matrix, column-major with leading dimension P_
// beta     -- scales the previous content of rV before the new result is added;
//             beta == 0 overwrites rV (the previous content, including NaN/Inf, is discarded)
// tx       -- thread index; selects the row of each product computed by this thread
// rTmp     -- scratch accumulator, passed by value (the incoming value is ignored)
// swork    -- workspace holding at least (P_ * P_ * Q_) + (P_ * Q_ * Q_) entries
template <typename T, int DIM_U, int DIM_V, int NCOMP_, int P_, int Q_, int rUsize, int rVsize, int iDIM_, int iDIM_U, int iDIM_V>
static __device__ __inline__ void magma_grad_3d_device(const T *sTinterp, const T *sTgrad, T rU[DIM_U][NCOMP_][rUsize], T rV[DIM_V][NCOMP_][rVsize],
                                                       T beta, const int tx, T rTmp, T *swork) {
  // Assumptions
  // 0. This device routine applies grad for one dim only (iDIM_), so it should be called thrice for 3D
  // 1. 1D threads of size max(P_,Q_)^2
  // 2. input:  rU[DIM_U x NCOMP_ x rUsize] in registers (per thread)
  // 3. output: rV[DIM_V x NCOMP_ x rVsize] in registers (per thread)
  // 4. Three products per each (dim,component) pair
  //  4.1 Batch P_^2 of (1xP_) matrices times (P_xQ_) matrix => Batch P_^2 of (1xQ_) matrices
  //  4.2 Batch P_   of (Q_xP_) matrices times (P_xQ_) matrix => Batch P_   of (Q_xQ_) matrices
  //  4.3 Batch 1   of (Q_^2xP_) matrix times (P_xQ_) matrix => (Q_^2xQ_) matrix
  // 5. Each thread computes one row of the output of each product
  // 6. Sync is recommended before and after the call

  T *sW1 = swork;               // output of product 4.1, input of product 4.2 (P_^2 * Q_ entries)
  T *sW2 = sW1 + P_ * P_ * Q_;  // output of product 4.2, input of product 4.3 (P_ * Q_^2 entries)

  // beta == 0 means "overwrite": assign instead of scaling, so that stale NaN/Inf values
  // in rV cannot leak into the result through 0 * NaN.
  const bool overwrite = (beta == T(0.0));

  for (int icomp = 0; icomp < NCOMP_; icomp++) {
    // Batch P_^2 of (1xP_) matrices [reg] times (P_xQ_) matrix [shmem] => Batch P_^2 of (1xQ_) matrices [shmem]
    if (tx < (P_ * P_)) {
      const int batchid = tx;
      const int sld     = 1;
      const T  *sT      = (iDIM_ == 0) ? sTgrad : sTinterp;
      T        *sTmp    = sW1 + batchid * (1 * Q_);
      for (int j = 0; j < Q_; j++) {
        rTmp = 0.0;
        for (int i = 0; i < P_; i++) {
          rTmp += rU[iDIM_U][icomp][i] * sT(i, j);
        }
        sTmp(0, j, sld) = rTmp;
      }
    }  // end of: if (tx < P_*P_)
    __syncthreads();  // sW1 is complete before other threads read it

    // Batch P_ of (Q_xP_) matrices [shmem] times (P_xQ_) matrix [shmem] => Batch P_ of (Q_xQ_) matrices [shmem]
    if (tx < (P_ * Q_)) {
      const int batchid = tx / Q_;
      const int tx_     = tx % Q_;
      const int sld     = Q_;
      const T  *sT      = (iDIM_ == 1) ? sTgrad : sTinterp;
      T        *sTmp    = sW1 + batchid * (Q_ * P_);  // sTmp is input
      T        *sTmp2   = sW2 + batchid * (Q_ * Q_);  // sTmp2 is output
      for (int j = 0; j < Q_; j++) {
        rTmp = 0.0;
        for (int i = 0; i < P_; i++) {
          rTmp += sTmp(tx_, i, sld) * sT(i, j);
        }
        sTmp2(tx_, j, sld) = rTmp;
      }
    }
    __syncthreads();  // sW2 is complete before other threads read it

    // Batch 1 of (Q_^2xP_) matrices [shmem] times (P_xQ_) matrix [shmem] => Batch 1 of (Q_^2xQ_) matrices [reg]
    if (tx < (Q_ * Q_)) {
      // No need to declare batchid = (tx  / Q_^2) = always zero
      // No need to declare tx_     = (tx_ % Q_^2) = always tx
      const int sld  = Q_ * Q_;
      const T  *sT   = (iDIM_ == 2) ? sTgrad : sTinterp;
      T        *sTmp = sW2;  // sTmp is input
      for (int j = 0; j < Q_; j++) {
        rTmp = 0.0;
        for (int i = 0; i < P_; i++) {
          rTmp += sTmp(tx, i, sld) * sT(i, j);
        }
        if (overwrite) {
          rV[iDIM_V][icomp][j] = rTmp;
        } else {
          rV[iDIM_V][icomp][j] *= beta;
          rV[iDIM_V][icomp][j] += rTmp;
        }
      }
    }
    __syncthreads();  // the workspace may be reused by the next component (or by the caller)
  }  // loop over NCOMP_
}
93 
//////////////////////////////////////////////////////////////////////////////////////////
// Non-transposed 3D gradient kernel.
// Thread layout: threadIdx.x is the thread index inside one element, threadIdx.y picks the
// element handled inside the block (element id = blockIdx.x * blockDim.y + threadIdx.y).
// Shared buffer layout (in CeedScalar entries):
//   [ interp: P*Q | grad: P*Q | one workspace of max(P^3, P^2*Q + P*Q^2) per threadIdx.y ]
// The same U (read from dU + 0 * dstrdU) is differentiated along dimensions 0, 1 and 2, and the
// three results are written to dV + 0 * dstrdV, dV + 1 * dstrdV and dV + 2 * dstrdV.
// NOTE(review): threads whose element id is >= nelem return before the barriers below are
// reached -- confirm this is acceptable on every targeted architecture.
extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(MAXPQ *MAXPQ, MAGMA_MAXTHREADS_3D)) __global__
    void magma_gradn_3d_kernel(const CeedScalar *dinterp1d, const CeedScalar *dgrad1d, const CeedScalar *dU, const int estrdU, const int cstrdU,
                               const int dstrdU, CeedScalar *dV, const int estrdV, const int cstrdV, const int dstrdV, const int nelem) {
  MAGMA_DEVICE_SHARED(CeedScalar, shared_data)

  const int tid_x = threadIdx.x;
  const int tid_y = threadIdx.y;
  const int elem  = (blockIdx.x * blockDim.y) + tid_y;

  // nothing to do for padding rows of the last block
  if (elem >= nelem) return;

  magma_trans_t trans_mode = MagmaNoTrans;

  // per-thread storage; the leading extent is 1 here, but might be different for a fused operator
  CeedScalar rU[1][NCOMP][P] = {0.0};
  CeedScalar rV[1][NCOMP][Q] = {0.0};
  CeedScalar rScratch        = 0.0;

  // advance the global pointers to this thread row's element
  dU += elem * estrdU;
  dV += elem * estrdV;

  // carve up the shared buffer
  const int   work_len = max(P * P * P, (P * P * Q) + (P * Q * Q));
  CeedScalar *sTinterp = (CeedScalar *)(shared_data);
  CeedScalar *sTgrad   = sTinterp + P * Q;
  CeedScalar *sWork    = sTgrad + P * Q;
  sWork += tid_y * work_len;

  // only the first thread row loads the two 1D matrices; everyone waits for them
  if (tid_y == 0) {
    dread_T_gm2sm<P, Q>(tid_x, trans_mode, dinterp1d, sTinterp);
    dread_T_gm2sm<P, Q>(tid_x, trans_mode, dgrad1d, sTgrad);
  }
  __syncthreads();

  // V is not read in this mode ( required only in transposed grad ), hence beta is zero
  const CeedScalar zero_beta = 0.0;

  /* load U once (idim = 0 for dU, iDIM = 0 for rU) --
     there is a sync at the end of this function */
  readU_3d<CeedScalar, P, 1, NCOMP, P, 0>(dU + (0 * dstrdU), cstrdU, rU, sWork, tid_x);

  /* derivative along dim 0 (iDIM = 0, iDIMU = 0, iDIMV = 0) --
     rV[0][][] goes to dV (idim = 0);
     there is a sync at the end of magma_grad_3d_device */
  magma_grad_3d_device<CeedScalar, 1, 1, NCOMP, P, Q, P, Q, 0, 0, 0>(sTinterp, sTgrad, rU, rV, zero_beta, tid_x, rScratch, sWork);
  writeV_3d<CeedScalar, Q, 1, NCOMP, Q, 0>(dV + (0 * dstrdV), cstrdV, rV, tid_x);

  /* derivative along dim 1 (iDIM = 1, iDIMU = 0, iDIMV = 0) --
     rV[0][][] goes to dV (idim = 1);
     there is a sync at the end of magma_grad_3d_device */
  magma_grad_3d_device<CeedScalar, 1, 1, NCOMP, P, Q, P, Q, 1, 0, 0>(sTinterp, sTgrad, rU, rV, zero_beta, tid_x, rScratch, sWork);
  writeV_3d<CeedScalar, Q, 1, NCOMP, Q, 0>(dV + (1 * dstrdV), cstrdV, rV, tid_x);

  /* derivative along dim 2 (iDIM = 2, iDIMU = 0, iDIMV = 0) --
     rV[0][][] goes to dV (idim = 2);
     there is a sync at the end of magma_grad_3d_device */
  magma_grad_3d_device<CeedScalar, 1, 1, NCOMP, P, Q, P, Q, 2, 0, 0>(sTinterp, sTgrad, rU, rV, zero_beta, tid_x, rScratch, sWork);
  writeV_3d<CeedScalar, Q, 1, NCOMP, Q, 0>(dV + (2 * dstrdV), cstrdV, rV, tid_x);
}
153 
//////////////////////////////////////////////////////////////////////////////////////////
// Transposed 3D gradient kernel.
// Thread layout: threadIdx.x is the thread index inside one element, threadIdx.y picks the
// element handled inside the block (element id = blockIdx.x * blockDim.y + threadIdx.y).
// Shared buffer layout (in CeedScalar entries):
//   [ interp: Q*P | grad: Q*P | one workspace of max(Q^3, Q^2*P + Q*P^2) per threadIdx.y ]
// rV is first read from dV + 0 * dstrdV; the contributions of the three U slices at
// dU + 0 * dstrdU, dU + 1 * dstrdU and dU + 2 * dstrdU are then accumulated into it (beta = 1),
// and the sum is written back to dV + 0 * dstrdV.
// NOTE(review): threads whose element id is >= nelem return before the barriers below are
// reached -- confirm this is acceptable on every targeted architecture.
extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(MAXPQ *MAXPQ, MAGMA_MAXTHREADS_3D)) __global__
    void magma_gradt_3d_kernel(const CeedScalar *dinterp1d, const CeedScalar *dgrad1d, const CeedScalar *dU, const int estrdU, const int cstrdU,
                               const int dstrdU, CeedScalar *dV, const int estrdV, const int cstrdV, const int dstrdV, const int nelem) {
  MAGMA_DEVICE_SHARED(CeedScalar, shared_data)

  const int tid_x = threadIdx.x;
  const int tid_y = threadIdx.y;
  const int elem  = (blockIdx.x * blockDim.y) + tid_y;

  // nothing to do for padding rows of the last block
  if (elem >= nelem) return;

  magma_trans_t trans_mode = MagmaTrans;

  // per-thread storage; the leading extent is 1 here, but might be different for a fused operator
  CeedScalar rU[1][NCOMP][Q] = {0.0};
  CeedScalar rV[1][NCOMP][P] = {0.0};
  CeedScalar rScratch        = 0.0;

  // advance the global pointers to this thread row's element
  dU += elem * estrdU;
  dV += elem * estrdV;

  // carve up the shared buffer
  const int   work_len = max(Q * Q * Q, (Q * Q * P) + (Q * P * P));
  CeedScalar *sTinterp = (CeedScalar *)(shared_data);
  CeedScalar *sTgrad   = sTinterp + Q * P;
  CeedScalar *sWork    = sTgrad + Q * P;
  sWork += tid_y * work_len;

  // only the first thread row loads the two 1D matrices; everyone waits for them
  if (tid_y == 0) {
    dread_T_gm2sm<Q, P>(tid_x, trans_mode, dinterp1d, sTinterp);
    dread_T_gm2sm<Q, P>(tid_x, trans_mode, dgrad1d, sTgrad);
  }
  __syncthreads();

  // transposed mode accumulates, so start from the current V and use beta = 1
  const CeedScalar unit_beta = 1.0;
  readV_3d<CeedScalar, P, 1, NCOMP, P, 0>(dV + (0 * dstrdV), cstrdV, rV, tid_x);

  /* contribution of dim 0: load U (idim = 0 for dU, iDIM = 0 for rU), then apply
     (iDIM = 0, iDIMU = 0, iDIMV = 0) --
     there is a sync at the end of both functions */
  readU_3d<CeedScalar, Q, 1, NCOMP, Q, 0>(dU + (0 * dstrdU), cstrdU, rU, sWork, tid_x);
  magma_grad_3d_device<CeedScalar, 1, 1, NCOMP, Q, P, Q, P, 0, 0, 0>(sTinterp, sTgrad, rU, rV, unit_beta, tid_x, rScratch, sWork);

  /* contribution of dim 1: load U (idim = 1 for dU, iDIM = 0 for rU), then apply
     (iDIM = 1, iDIMU = 0, iDIMV = 0) --
     there is a sync at the end of both functions */
  readU_3d<CeedScalar, Q, 1, NCOMP, Q, 0>(dU + (1 * dstrdU), cstrdU, rU, sWork, tid_x);
  magma_grad_3d_device<CeedScalar, 1, 1, NCOMP, Q, P, Q, P, 1, 0, 0>(sTinterp, sTgrad, rU, rV, unit_beta, tid_x, rScratch, sWork);

  /* contribution of dim 2: load U (idim = 2 for dU, iDIM = 0 for rU), then apply
     (iDIM = 2, iDIMU = 0, iDIMV = 0) --
     there is a sync at the end of both functions */
  readU_3d<CeedScalar, Q, 1, NCOMP, Q, 0>(dU + (2 * dstrdU), cstrdU, rU, sWork, tid_x);
  magma_grad_3d_device<CeedScalar, 1, 1, NCOMP, Q, P, Q, P, 2, 0, 0>(sTinterp, sTgrad, rU, rV, unit_beta, tid_x, rScratch, sWork);

  // store the accumulated V
  writeV_3d<CeedScalar, P, 1, NCOMP, P, 0>(dV + (0 * dstrdV), cstrdV, rV, tid_x);
}
216