xref: /libCEED/include/ceed/jit-source/magma/magma-basis-grad-3d.h (revision 715f9ba89a309f24226005ca1fbb9f59fe9eac68)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 /// @file
9 /// Internal header for MAGMA tensor basis gradient in 3D
10 #ifndef CEED_MAGMA_BASIS_GRAD_3D_H
11 #define CEED_MAGMA_BASIS_GRAD_3D_H
12 
13 #include "magma-common-tensor.h"
14 
15 // macros to abstract access of shared memory and reg. file
16 #define sT(i, j) sT[(j)*P + (i)]
17 #define sTmp(i, j, ldw) sTmp[(j) * (ldw) + (i)]
18 #define sTmp2(i, j, ldw) sTmp2[(j) * (ldw) + (i)]
19 
//////////////////////////////////////////////////////////////////////////////////////////
// grad basis action (3D)
// This function is called three times at a higher level for 3D
// DIM_U   -- for the size of rU[DIM_U * NUM_COMP * MAX_P_Q]
// DIM_V   -- for the size of rV[DIM_V * NUM_COMP * MAX_P_Q]
// i_DIM   -- the index of the outermost loop over dimensions in grad
// i_DIM_U -- which dim index of rU is accessed (always 0 for notrans, 0, 1, or 2 for trans)
// i_DIM_V -- which dim index of rV is accessed (0, 1, or 2 for notrans, always 0 for trans)
// the scalar beta is used to specify whether to accumulate to rV (beta = 1), or overwrite it (beta = 0)
template <typename T, int DIM_U, int DIM_V, int NUM_COMP, int P, int Q, int rU_SIZE, int rV_SIZE, int i_DIM, int i_DIM_U, int i_DIM_V>
static __device__ __inline__ void magma_grad_3d_device(const T *sTinterp, const T *sTgrad, T rU[DIM_U][NUM_COMP][rU_SIZE],
                                                       T rV[DIM_V][NUM_COMP][rV_SIZE], T beta, const int tx, T rTmp, T *swork) {
  // Assumptions
  // 0. This device routine applies grad for one dim only (i_DIM), so it should be called three times for 3D
  // 1. 1D threads of size max(P,Q)^2
  // 2. input:  rU[DIM_U x NUM_COMP x rU_SIZE] in registers (per thread)
  // 3. output: rV[DIM_V x NUM_COMP x rV_SIZE] in registers (per thread)
  // 4. Three products per each (dim,component) pair
  //  4.1 Batch P^2 of (1xP) matrices times (PxQ) matrix => Batch P^2 of (1xQ) matrices
  //  4.2 Batch P   of (QxP) matrices times (PxQ) matrix => Batch P   of (QxQ) matrices
  //  4.3 Batch 1   of (Q^2xP) matrix times (PxQ) matrix => (Q^2xQ) matrix
  // 5. Each thread computes one row of the output of each product
  // 6. Sync is recommended before and after the call

  // partition the shared workspace into two staging buffers:
  // sW1 receives the P^2 x Q output of product 4.1, sW2 receives the P x Q^2 output of product 4.2
  T *sW1 = swork;
  T *sW2 = sW1 + P * P * Q;
  for (int comp = 0; comp < NUM_COMP; comp++) {
    // Batch P^2 of (1xP) matrices [reg] times (PxQ) matrix [shmem] => Batch P^2 of (1xQ) matrices [shmem]
    if (tx < (P * P)) {
      const int batchid = tx;
      const int sld     = 1;
      // the grad matrix is applied only along the dimension being differentiated (i_DIM); interp otherwise
      const T  *sT      = (i_DIM == 0) ? sTgrad : sTinterp;
      T        *sTmp    = sW1 + batchid * (1 * Q);
      for (int j = 0; j < Q; j++) {
        rTmp = 0.0;
        for (int i = 0; i < P; i++) {
          rTmp += rU[i_DIM_U][comp][i] * sT(i, j);
        }
        sTmp(0, j, sld) = rTmp;
      }
    }  // end of: if (tx < P*P)
    __syncthreads();

    // Batch P of (QxP) matrices [shmem] times (PxQ) matrix [shmem] => Batch P of (QxQ) matrices [shmem]
    if (tx < (P * Q)) {
      const int batchid = tx / Q;
      const int tx_     = tx % Q;
      const int sld     = Q;
      const T  *sT      = (i_DIM == 1) ? sTgrad : sTinterp;
      T        *sTmp    = sW1 + batchid * (Q * P);  // sTmp is input
      T        *sTmp2   = sW2 + batchid * (Q * Q);  // sTmp2 is output
      for (int j = 0; j < Q; j++) {
        rTmp = 0.0;
        for (int i = 0; i < P; i++) {
          rTmp += sTmp(tx_, i, sld) * sT(i, j);
        }
        sTmp2(tx_, j, sld) = rTmp;
      }
    }
    __syncthreads();

    // Batch 1 of (Q^2xP) matrices [shmem] times (PxQ) matrix [shmem] => Batch 1 of (Q^2xQ) matrices [reg]
    if (tx < (Q * Q)) {
      // No need to declare batchid = (tx  / Q^2) = always zero
      // No need to declare tx_     = (tx_ % Q^2) = always tx
      const int sld  = Q * Q;
      const T  *sT   = (i_DIM == 2) ? sTgrad : sTinterp;
      T        *sTmp = sW2;  // sTmp is input
      for (int j = 0; j < Q; j++) {
        rTmp = 0.0;
        for (int i = 0; i < P; i++) {
          rTmp += sTmp(tx, i, sld) * sT(i, j);
        }
        // beta selects accumulate (1.0) vs overwrite (0.0) of the output registers
        rV[i_DIM_V][comp][j] *= beta;
        rV[i_DIM_V][comp][j] += rTmp;
      }
    }
    __syncthreads();
  }  // loop over NUM_COMP
}
100 
//////////////////////////////////////////////////////////////////////////////////////////
// non-transposed grad kernel (3D): dV[idim] = grad_idim(dU), idim = 0, 1, 2
extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_MAX_P_Q *BASIS_MAX_P_Q, MAGMA_MAXTHREADS_3D)) __global__
    void magma_gradn_3d_kernel(const CeedScalar *dinterp1d, const CeedScalar *dgrad1d, const CeedScalar *dU, const int estrdU, const int cstrdU,
                               const int dstrdU, CeedScalar *dV, const int estrdV, const int cstrdV, const int dstrdV, const int nelem) {
  MAGMA_DEVICE_SHARED(CeedScalar, shared_data)

  const int     tx    = threadIdx.x;
  const int     ty    = threadIdx.y;
  const int     elem  = (blockIdx.x * blockDim.y) + ty;
  magma_trans_t trans = MagmaNoTrans;

  if (elem >= nelem) return;

  // per-thread register storage; the leading dimension is 1 here, but might be different for a fused operator
  CeedScalar regU[1][BASIS_NUM_COMP][BASIS_P] = {0.0};
  CeedScalar regV[1][BASIS_NUM_COMP][BASIS_Q] = {0.0};
  CeedScalar regTmp                           = 0.0;

  // advance the global memory pointers to this element
  dU += elem * estrdU;
  dV += elem * estrdV;

  // carve up shared memory: interp matrix, grad matrix, then a per-ty workspace
  CeedScalar *sTinterp = (CeedScalar *)shared_data;
  CeedScalar *sTgrad   = sTinterp + BASIS_P * BASIS_Q;
  CeedScalar *sWork    = sTgrad + BASIS_P * BASIS_Q;
  sWork += ty * (max(BASIS_P * BASIS_P * BASIS_P, (BASIS_P * BASIS_P * BASIS_Q) + (BASIS_P * BASIS_Q * BASIS_Q)));

  // one row of threads (ty == 0) loads both basis matrices into shared memory
  if (ty == 0) {
    dread_T_gm2sm<BASIS_P, BASIS_Q>(tx, trans, dinterp1d, sTinterp);
    dread_T_gm2sm<BASIS_P, BASIS_Q>(tx, trans, dgrad1d, sTgrad);
  }
  __syncthreads();

  // non-transposed mode overwrites V, so V is not read and beta is zero
  const CeedScalar beta = 0.0;

  // read U once (idim = 0 for dU, i_DIM = 0 for regU) -- readU_3d syncs at its end
  readU_3d<CeedScalar, BASIS_P, 1, BASIS_NUM_COMP, BASIS_P, 0>(dU + (0 * dstrdU), cstrdU, regU, sWork, tx);

  // apply grad along each of the three dimensions in turn; magma_grad_3d_device syncs
  // at its end, and each result in regV[0][][] is written to the matching dV slice

  // first call (i_DIM = 0, i_DIM_U = 0, i_DIM_V = 0) -- output into dV (idim = 0)
  magma_grad_3d_device<CeedScalar, 1, 1, BASIS_NUM_COMP, BASIS_P, BASIS_Q, BASIS_P, BASIS_Q, 0, 0, 0>(sTinterp, sTgrad, regU, regV, beta, tx, regTmp, sWork);
  writeV_3d<CeedScalar, BASIS_Q, 1, BASIS_NUM_COMP, BASIS_Q, 0>(dV + (0 * dstrdV), cstrdV, regV, tx);

  // second call (i_DIM = 1, i_DIM_U = 0, i_DIM_V = 0) -- output into dV (idim = 1)
  magma_grad_3d_device<CeedScalar, 1, 1, BASIS_NUM_COMP, BASIS_P, BASIS_Q, BASIS_P, BASIS_Q, 1, 0, 0>(sTinterp, sTgrad, regU, regV, beta, tx, regTmp, sWork);
  writeV_3d<CeedScalar, BASIS_Q, 1, BASIS_NUM_COMP, BASIS_Q, 0>(dV + (1 * dstrdV), cstrdV, regV, tx);

  // third call (i_DIM = 2, i_DIM_U = 0, i_DIM_V = 0) -- output into dV (idim = 2)
  magma_grad_3d_device<CeedScalar, 1, 1, BASIS_NUM_COMP, BASIS_P, BASIS_Q, BASIS_P, BASIS_Q, 2, 0, 0>(sTinterp, sTgrad, regU, regV, beta, tx, regTmp, sWork);
  writeV_3d<CeedScalar, BASIS_Q, 1, BASIS_NUM_COMP, BASIS_Q, 0>(dV + (2 * dstrdV), cstrdV, regV, tx);
}
160 
//////////////////////////////////////////////////////////////////////////////////////////
// transposed grad kernel (3D): dV += sum_idim grad_idim^T(dU[idim]), idim = 0, 1, 2
extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_MAX_P_Q *BASIS_MAX_P_Q, MAGMA_MAXTHREADS_3D)) __global__
    void magma_gradt_3d_kernel(const CeedScalar *dinterp1d, const CeedScalar *dgrad1d, const CeedScalar *dU, const int estrdU, const int cstrdU,
                               const int dstrdU, CeedScalar *dV, const int estrdV, const int cstrdV, const int dstrdV, const int nelem) {
  MAGMA_DEVICE_SHARED(CeedScalar, shared_data)

  const int     tx    = threadIdx.x;
  const int     ty    = threadIdx.y;
  const int     elem  = (blockIdx.x * blockDim.y) + ty;
  magma_trans_t trans = MagmaTrans;

  if (elem >= nelem) return;

  // per-thread register storage; the leading dimension is 1 here, but might be different for a fused operator
  CeedScalar regU[1][BASIS_NUM_COMP][BASIS_Q] = {0.0};
  CeedScalar regV[1][BASIS_NUM_COMP][BASIS_P] = {0.0};
  CeedScalar regTmp                           = 0.0;

  // advance the global memory pointers to this element
  dU += elem * estrdU;
  dV += elem * estrdV;

  // carve up shared memory: interp matrix, grad matrix, then a per-ty workspace
  CeedScalar *sTinterp = (CeedScalar *)shared_data;
  CeedScalar *sTgrad   = sTinterp + BASIS_Q * BASIS_P;
  CeedScalar *sWork    = sTgrad + BASIS_Q * BASIS_P;
  sWork += ty * (max(BASIS_Q * BASIS_Q * BASIS_Q, (BASIS_Q * BASIS_Q * BASIS_P) + (BASIS_Q * BASIS_P * BASIS_P)));

  // one row of threads (ty == 0) loads both basis matrices (transposed) into shared memory
  if (ty == 0) {
    dread_T_gm2sm<BASIS_Q, BASIS_P>(tx, trans, dinterp1d, sTinterp);
    dread_T_gm2sm<BASIS_Q, BASIS_P>(tx, trans, dgrad1d, sTgrad);
  }
  __syncthreads();

  // transposed mode accumulates into V, so read V first and use beta = 1
  const CeedScalar beta = 1.0;
  readV_3d<CeedScalar, BASIS_P, 1, BASIS_NUM_COMP, BASIS_P, 0>(dV + (0 * dstrdV), cstrdV, regV, tx);

  // for each dU slice (idim = 0, 1, 2): read U into regU[0][][], then apply the transposed
  // grad for that dimension; readU_3d and magma_grad_3d_device each sync at their end

  // idim = 0: first call (i_DIM = 0, i_DIM_U = 0, i_DIM_V = 0)
  readU_3d<CeedScalar, BASIS_Q, 1, BASIS_NUM_COMP, BASIS_Q, 0>(dU + (0 * dstrdU), cstrdU, regU, sWork, tx);
  magma_grad_3d_device<CeedScalar, 1, 1, BASIS_NUM_COMP, BASIS_Q, BASIS_P, BASIS_Q, BASIS_P, 0, 0, 0>(sTinterp, sTgrad, regU, regV, beta, tx, regTmp, sWork);

  // idim = 1: second call (i_DIM = 1, i_DIM_U = 0, i_DIM_V = 0)
  readU_3d<CeedScalar, BASIS_Q, 1, BASIS_NUM_COMP, BASIS_Q, 0>(dU + (1 * dstrdU), cstrdU, regU, sWork, tx);
  magma_grad_3d_device<CeedScalar, 1, 1, BASIS_NUM_COMP, BASIS_Q, BASIS_P, BASIS_Q, BASIS_P, 1, 0, 0>(sTinterp, sTgrad, regU, regV, beta, tx, regTmp, sWork);

  // idim = 2: third call (i_DIM = 2, i_DIM_U = 0, i_DIM_V = 0)
  readU_3d<CeedScalar, BASIS_Q, 1, BASIS_NUM_COMP, BASIS_Q, 0>(dU + (2 * dstrdU), cstrdU, regU, sWork, tx);
  magma_grad_3d_device<CeedScalar, 1, 1, BASIS_NUM_COMP, BASIS_Q, BASIS_P, BASIS_Q, BASIS_P, 2, 0, 0>(sTinterp, sTgrad, regU, regV, beta, tx, regTmp, sWork);

  // write the accumulated V back
  writeV_3d<CeedScalar, BASIS_P, 1, BASIS_NUM_COMP, BASIS_P, 0>(dV + (0 * dstrdV), cstrdV, regV, tx);
}
223 
224 #endif  // CEED_MAGMA_BASIS_GRAD_3D_H
225