// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors. // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. // // SPDX-License-Identifier: BSD-2-Clause // // This file is part of CEED: http://github.com/ceed #ifndef CEED_MAGMA_COMMON_NONTENSOR_H #define CEED_MAGMA_COMMON_NONTENSOR_H #define NONTENSOR_MAX_THREADS (128) #ifndef MAGMA_DEVICE_SHARED #define MAGMA_DEVICE_SHARED #ifdef CEED_MAGMA_USE_HIP #define MAGMA_DEVICE_SHARED(type, name) HIP_DYNAMIC_SHARED(type, name) #else #define MAGMA_DEVICE_SHARED(type, name) extern __shared__ type name[]; #endif // CEED_MAGMA_USE_HIP #endif // MAGMA_DEVICE_SHARED #define MAGMA_NONTENSOR_BASIS_NTCOL(N) (MAGMA_MAX(1, (NONTENSOR_MAX_THREADS / (N)))) #define dA(i, j) dA[(j)*ldda + (i)] #define sA(i, j) sA[(j)*slda + (i)] #define dB(i, j) dB[(j)*lddb + (i)] #define sB(i, j) sB[(j)*sldb + (i)] //////////////////////////////////////////////////////////////////////////////// // read C from global to reg. // C is (P_ x NB_) // 1D thread config. with (Mx1) threads // no sync at the end of the function template static __device__ __inline__ void read_C_g2r_1D_nosync(const int tx, const int n, T *dC, int lddc, const T &beta, T rC[NB_]) { if (n != NB_) { #pragma unroll for (int j = 0; j < NB_; j++) { rC[j] = (j < n) ? beta * dC[j * lddc + tx] : 0; } } else { #pragma unroll for (int j = 0; j < NB_; j++) { rC[j] = beta * dC[j * lddc + tx]; } } } //////////////////////////////////////////////////////////////////////////////// // write C from reg. to global // C is (P_ x NB_) // 1D thread config. with (Mx1) threads // no sync at the end of the function template static __device__ __inline__ void write_C_r2g_1D_nosync(const int tx, const int n, T rC[NB_], T *dC, int lddc) { if (n != NB_) { #pragma unroll for (int j = 0; j < NB_; j++) { if (j < n) { dC[j * lddc + tx] = rC[j]; } } } else { #pragma unroll for (int j = 0; j < NB_; j++) { dC[j * lddc + tx] = rC[j]; } } } //////////////////////////////////////////////////////////////////////////////// // read A (no-trans) from global to reg. // A is (P_ x Q_) // 1D thread config. with (Mx1) threads // no sync at the end of the function template static __device__ __inline__ void read_A_notrans_g2r_1D_nosync(const int tx, const T *dA, int ldda, T *sA, int slda, T rA[Q_]) { #pragma unroll for (int j = 0; j < Q_; j++) { rA[j] = dA(tx, j); } } //////////////////////////////////////////////////////////////////////////////// // read A (no-trans) from global to reg. // A is (P_ x Q_) // 1D thread config. with (Mx1) threads // no sync at the end of the function template static __device__ __inline__ void read_A_trans_g2r_1D_nosync(const int tx, const int ty, const T *dA, int ldda, T *sA, int slda, T rA[Q_]) { int ix = 0; const int nTH = P_ * MAGMA_NONTENSOR_BASIS_NTCOL(P_); const int tid = ty * blockDim.x + tx; #pragma unroll for (ix = 0; ix < (Q_ * P_) - nTH; ix += nTH) { sA[ix + tid] = dA[ix + tid]; } if (tid < ((Q_ * P_) - ix)) { sA[ix + tid] = dA[ix + tid]; } __syncthreads(); #pragma unroll for (int j = 0; j < Q_; j++) { rA[j] = sA[tx * slda + j]; } } //////////////////////////////////////////////////////////////////////////////// // read B from global to shared // B is (Q_ x NB_) // 1D thread config. with (Mx1) threads // no sync at the end of the function template static __device__ __inline__ void read_B_g2s_1D_nosync(const int tx, int n, const T *dB, int lddb, T *sB, int sldb) { if (n != NB_) { for (int i = 0; i < (Q_ * n) - P_; i += P_) { sB[i + tx] = dB[i + tx]; } } else { #pragma unroll for (int i = 0; i < (Q_ * NB_) - P_; i += P_) { sB[i + tx] = dB[i + tx]; } } // cleanup for B const int stride = MAGMA_ROUNDUP(Q_ * n - P_, P_); if (tx < (Q_ * n) - stride) { sB[stride + tx] = dB[stride + tx]; } } //////////////////////////////////////////////////////////////////////////////// // multiply C = AxB using 1D threads in Mx1 config // A (MxK) in reg., one row per thread // B (KxNB) in shared memory // C in registers -- one row per thread // no sync at the end of the function template static __device__ __inline__ void mul_rAsBrC_1D_nosync(const int tx, const T &alpha, T rA[Q_], T *sB, int sldb, T rC[NB_]) { T rB[Q_] = {0}; #pragma unroll for (int i = 0; i < NB_; i++) { #pragma unroll for (int k = 0; k < Q_; k++) { rB[k] = sB[i * sldb + k]; } T rTmp = 0; #pragma unroll for (int k = 0; k < Q_; k++) { rTmp += rA[k] * rB[k]; } rC[i] += alpha * rTmp; } } #undef dA #undef sA #undef dB #undef sB #endif // CEED_MAGMA_COMMON_NONTENSOR_H