xref: /libCEED/rust/libceed-sys/c-src/include/ceed/jit-source/magma/magma-common-nontensor.h (revision f80f4a748154eed4bc661c135f695b92b1bc45b9)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 #ifndef CEED_MAGMA_COMMON_NONTENSOR_H
9 #define CEED_MAGMA_COMMON_NONTENSOR_H
10 
// Thread-count budget consumed by MAGMA_NONTENSOR_BASIS_NTCOL below.
#define NONTENSOR_MAX_THREADS (128)

// MAGMA_DEVICE_SHARED(type, name) declares the dynamically sized shared-memory array `name` with elements of `type`.
// HIP builds (CEED_MAGMA_USE_HIP defined) forward to HIP_DYNAMIC_SHARED; otherwise an `extern __shared__` array is declared.
// The outer #ifndef leaves an already existing definition (e.g. from a previously included header) untouched.
// The name is defined exactly once, directly as a function-like macro: defining it first as an empty object-like macro and
// then again with parameters is an incompatible macro redefinition that compilers diagnose.
#ifndef MAGMA_DEVICE_SHARED
#ifdef CEED_MAGMA_USE_HIP
#define MAGMA_DEVICE_SHARED(type, name) HIP_DYNAMIC_SHARED(type, name)
#else
#define MAGMA_DEVICE_SHARED(type, name) extern __shared__ type name[];
#endif  // CEED_MAGMA_USE_HIP
#endif  // MAGMA_DEVICE_SHARED

// Integer quotient NONTENSOR_MAX_THREADS / N, clamped from below to 1 through MAGMA_MAX (defined outside this header).
// read_A_trans_g2r_1D_nosync multiplies this value by P_ to obtain the number of cooperating threads;
// NOTE(review): presumably this matches the y-extent of the thread block -- confirm against the kernel launch.
#define MAGMA_NONTENSOR_BASIS_NTCOL(N) (MAGMA_MAX(1, (NONTENSOR_MAX_THREADS / (N))))
23 
// Column-major element accessors: X(i, j) is the entry at row i, column j of the array X,
// i.e. offset i + j * ld, where ld is the matching leading dimension.
// Each macro expects a pointer with the macro's own name (dA, sA, dB, sB) and the corresponding
// integer leading dimension (ldda, slda, lddb, sldb) to be visible at the point of use.
#define dA(i, j) dA[(i) + (j) * ldda]
#define sA(i, j) sA[(i) + (j) * slda]
#define dB(i, j) dB[(i) + (j) * lddb]
#define sB(i, j) sB[(i) + (j) * sldb]
28 
////////////////////////////////////////////////////////////////////////////////
// Load one row of C from global memory into registers, scaled by beta.
// C is (P_ x NB_); thread tx handles row tx, reading dC[col * lddc + tx].
//   tx   - row handled by the calling thread
//   n    - number of valid columns; when n != NB_, entries with col >= n are not read and are set to 0
//   dC   - source array, lddc is its leading dimension
//   beta - scaling factor applied to every value that is read
//   rC   - destination, NB_ entries, all of them are written
// 1D thread config. with (Mx1) threads
// no sync at the end of the function
template <typename T, int P_, int NB_, int Q_>
static __device__ __inline__ void read_C_g2r_1D_nosync(const int tx, const int n, T *dC, int lddc, const T &beta, T rC[NB_]) {
  // Full block of columns: no per-column bounds test is needed.
  if (n == NB_) {
#pragma unroll
    for (int col = 0; col < NB_; ++col) {
      rC[col] = beta * dC[col * lddc + tx];
    }
    return;
  }

  // Partial block: columns at or beyond n are zero-filled without touching dC.
#pragma unroll
  for (int col = 0; col < NB_; ++col) {
    rC[col] = (col >= n) ? 0 : beta * dC[col * lddc + tx];
  }
}
48 
////////////////////////////////////////////////////////////////////////////////
// Store one row of C from registers to global memory.
// C is (P_ x NB_); thread tx handles row tx, writing dC[col * lddc + tx] = rC[col].
//   tx - row handled by the calling thread
//   n  - number of valid columns; when n != NB_, only columns col < n are written
//   rC - source registers, NB_ entries
//   dC - destination array, lddc is its leading dimension
// 1D thread config. with (Mx1) threads
// no sync at the end of the function
template <typename T, int P_, int NB_, int Q_>
static __device__ __inline__ void write_C_r2g_1D_nosync(const int tx, const int n, T rC[NB_], T *dC, int lddc) {
  // Full block of columns: every register value is stored.
  if (n == NB_) {
#pragma unroll
    for (int col = 0; col < NB_; ++col) {
      dC[col * lddc + tx] = rC[col];
    }
    return;
  }

  // Partial block: skip the columns that lie outside the valid range.
#pragma unroll
  for (int col = 0; col < NB_; ++col) {
    if (col < n) {
      dC[col * lddc + tx] = rC[col];
    }
  }
}
70 
////////////////////////////////////////////////////////////////////////////////
// Load one row of A (no-trans) from global memory into registers.
// A is (P_ x Q_); thread tx copies dA(tx, col) = dA[col * ldda + tx] into rA[col]
// for col = 0, ..., Q_ - 1.
//   tx   - row handled by the calling thread
//   dA   - source array, ldda is its leading dimension
//   sA   - not used by this function
//   slda - not used by this function
//   rA   - destination registers, Q_ entries
// 1D thread config. with (Mx1) threads
// no sync at the end of the function
template <typename T, int P_, int NB_, int Q_>
static __device__ __inline__ void read_A_notrans_g2r_1D_nosync(const int tx, const T *dA, int ldda, T *sA, int slda, T rA[Q_]) {
#pragma unroll
  for (int col = 0; col != Q_; ++col) {
    rA[col] = dA(tx, col);
  }
}
83 
////////////////////////////////////////////////////////////////////////////////
// Load A (trans) from global memory into registers, staged through shared memory.
// Step 1: the threads copy the first Q_ * P_ entries of dA into sA as a flat array,
//         P_ * MAGMA_NONTENSOR_BASIS_NTCOL(P_) entries per pass, indexed by the linear
//         thread id ty * blockDim.x + tx.
// Step 2: after a block-wide barrier, thread tx reads rA[k] = sA[tx * slda + k]
//         for k = 0, ..., Q_ - 1.
//   tx, ty - thread coordinates of the caller
//   dA     - source array; ldda is not used, the copy in step 1 is a flat one
//   sA     - staging buffer, slda is the stride between the entries of consecutive tx
//   rA     - destination registers, Q_ entries
// NOTE(review): the pass size is a compile-time constant, so the block presumably has exactly
// P_ * MAGMA_NONTENSOR_BASIS_NTCOL(P_) threads -- confirm against the kernel launch.
// Contains a __syncthreads() between the two steps, so every thread of the block must call it;
// no sync at the end of the function
template <typename T, int P_, int NB_, int Q_>
static __device__ __inline__ void read_A_trans_g2r_1D_nosync(const int tx, const int ty, const T *dA, int ldda, T *sA, int slda, T rA[Q_]) {
  const int num_entries = Q_ * P_;
  const int num_threads = P_ * MAGMA_NONTENSOR_BASIS_NTCOL(P_);
  const int linear_id   = ty * blockDim.x + tx;
  int       offset      = 0;

  // Full passes: every thread copies exactly one entry per pass.
#pragma unroll
  for (offset = 0; offset < num_entries - num_threads; offset += num_threads) {
    sA[offset + linear_id] = dA[offset + linear_id];
  }

  // Last (possibly partial) pass: only the threads that still map to an entry take part.
  if (linear_id < num_entries - offset) {
    sA[offset + linear_id] = dA[offset + linear_id];
  }
  // The staged data is read by other threads than the ones that wrote it.
  __syncthreads();

#pragma unroll
  for (int k = 0; k < Q_; k++) {
    rA[k] = sA[tx * slda + k];
  }
}
110 
////////////////////////////////////////////////////////////////////////////////
// Copy B from global to shared memory.
// B is (Q_ x NB_); the first Q_ * n entries of dB are copied to sB as a flat array,
// in chunks of P_ entries, thread tx handling position tx of every chunk.
//   tx   - index of the calling thread inside a chunk
//   n    - number of valid columns of B
//   dB   - source array; lddb is not used, the copy is a flat one
//   sB   - destination array; sldb is not used, the copy is a flat one
// NOTE(review): the chunk size is P_, so presumably tx ranges over [0, P_) -- confirm against the caller.
// 1D thread config. with (Mx1) threads
// no sync at the end of the function
template <typename T, int P_, int NB_, int Q_>
static __device__ __inline__ void read_B_g2s_1D_nosync(const int tx, int n, const T *dB, int lddb, T *sB, int sldb) {
  if (n == NB_) {
    // Trip count is a compile-time constant here, so the loop can be unrolled.
#pragma unroll
    for (int offset = 0; offset < (Q_ * NB_) - P_; offset += P_) {
      sB[offset + tx] = dB[offset + tx];
    }
  } else {
    for (int offset = 0; offset < (Q_ * n) - P_; offset += P_) {
      sB[offset + tx] = dB[offset + tx];
    }
  }

  // Last (possibly partial) chunk: only threads that still map to a valid entry copy.
  const int tail_start = MAGMA_ROUNDUP(Q_ * n - P_, P_);
  if (tx < (Q_ * n) - tail_start) {
    sB[tail_start + tx] = dB[tail_start + tx];
  }
}
135 
////////////////////////////////////////////////////////////////////////////////
// Accumulate C += alpha * A x B using 1D threads in Mx1 config.
// A (MxK)  in reg., one row per thread (rA, Q_ entries)
// B (KxNB) in shared memory, column col starts at sB[col * sldb]
// C in registers -- one row per thread (rC, NB_ entries, updated in place)
// For every column, the Q_ entries of B are first copied to a local array, then the
// dot product with rA is formed in increasing k order and added to rC scaled by alpha.
// tx is not used by this function.
// no sync at the end of the function
template <typename T, int P_, int NB_, int Q_>
static __device__ __inline__ void mul_rAsBrC_1D_nosync(const int tx, const T &alpha, T rA[Q_], T *sB, int sldb, T rC[NB_]) {
  T b_col[Q_] = {0};
#pragma unroll
  for (int col = 0; col < NB_; ++col) {
    const T *src = &sB[col * sldb];

    // Stage the current column of B locally.
#pragma unroll
    for (int k = 0; k < Q_; ++k) {
      b_col[k] = src[k];
    }

    // Dot product of this thread's row of A with the staged column.
    T dot = 0;
#pragma unroll
    for (int k = 0; k < Q_; ++k) {
      dot += rA[k] * b_col[k];
    }
    rC[col] += alpha * dot;
  }
}
160 
// The element accessor macros share their names with pointer parameters, so they are
// removed again here to keep them from leaking into code that includes this header.
#undef dA
#undef dB
#undef sA
#undef sB
165 
166 #endif  // CEED_MAGMA_COMMON_NONTENSOR_H
167