xref: /libCEED/include/ceed/jit-source/magma/magma-common-nontensor.h (revision 715f9ba89a309f24226005ca1fbb9f59fe9eac68)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 /// @file
9 /// Internal header for MAGMA backend common non-tensor basis definitions
10 #ifndef CEED_MAGMA_COMMON_NONTENSOR_H
11 #define CEED_MAGMA_COMMON_NONTENSOR_H
12 
13 #include "magma-common-defs.h"
14 
////////////////////////////////////////////////////////////////////////////////
// Read A (no-trans) from global memory into registers.
// A is (P x Q), accessed column-major with leading dimension ldda.
// Thread tx loads row tx of A: rA[j] = dA[j * ldda + tx] for 0 <= j < Q, so for a
// fixed column consecutive threads read consecutive addresses.
// Intended for a 1D thread config. with (P x 1) threads; tx is not bounds-checked.
// sA and slda are not used by this variant (presumably kept so the signature
// parallels read_A_trans_g2r_1D_nosync — NOTE(review): confirm).
// No synchronization is performed.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void read_A_notrans_g2r_1D_nosync(const int tx, const T *dA, int ldda, T *sA, int slda, T rA[Q]) {
  // Start of this thread's row; successive columns are ldda elements apart.
  const T *row = dA + tx;
#pragma unroll
  for (int col = 0; col < Q; col++) {
    rA[col] = row[col * ldda];
  }
}
27 
////////////////////////////////////////////////////////////////////////////////
// Read A (trans) from global memory into registers, staging through shared memory.
// Logically A is (P x Q). The function works in two phases:
//   1. Every thread of the block cooperates to copy the Q * P contiguous elements
//      dA[0 .. Q*P) into sA[0 .. Q*P). The flattened thread id is ty * blockDim.x + tx,
//      and the copy strides by MAGMA_BASIS_BOUNDS(P, MAGMA_MAXTHREADS_1D). Elements
//      are only fully covered if the block has at least that many threads
//      (NOTE(review): confirm this matches the launch configuration).
//   2. After a barrier, thread tx gathers rA[j] = sA[tx * slda + j] for 0 <= j < Q.
// ldda is not used by the copy, so dA is presumably densely packed (ldda == Q) with
// slda == Q. NOTE(review): confirm with callers.
// Contains a __syncthreads(), so every thread of the block must call this function.
// No sync at the end of the function.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void read_A_trans_g2r_1D_nosync(const int tx, const int ty, const T *dA, int ldda, T *sA, int slda, T rA[Q]) {
  const int num_threads = MAGMA_BASIS_BOUNDS(P, MAGMA_MAXTHREADS_1D);
  const int flat_id     = tx + ty * blockDim.x;
  const int total       = Q * P;

  // Full passes: each one moves num_threads consecutive elements global -> shared.
  int offset = 0;
#pragma unroll
  for (; offset < total - num_threads; offset += num_threads) {
    sA[offset + flat_id] = dA[offset + flat_id];
  }
  // Final pass: at most num_threads elements remain, starting at offset.
  if (flat_id < total - offset) {
    sA[offset + flat_id] = dA[offset + flat_id];
  }

  // All of sA must be written before any thread gathers from it.
  __syncthreads();

  // Thread tx reads its Q values from the tx-th slda-strided segment of sA.
#pragma unroll
  for (int k = 0; k < Q; k++) {
    rA[k] = sA[tx * slda + k];
  }
}
53 
////////////////////////////////////////////////////////////////////////////////
// Read B from global to shared memory.
// B is (Q x NB). Only the first n columns are copied, as one flat contiguous range:
// sB[e] = dB[e] for 0 <= e < Q * n.
// lddb and sldb are not used, so B is presumably densely packed with leading
// dimension Q on both sides. NOTE(review): confirm with callers.
// 1D thread config. with (P x 1) threads: thread tx copies elements tx, tx + P,
// tx + 2P, ... (ty plays no role here).
// When n == NB the trip count is a compile-time constant and the loop is unrolled.
// A non-positive n copies nothing.
// No sync at the end of the function: the caller must synchronize before other
// threads read sB.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void read_B_g2s_1D_nosync(const int tx, const int n, const T *dB, int lddb, T *sB, int sldb) {
  // The counter lives outside the loops so that, once they finish, it is exactly the
  // offset of the final partial pass. This avoids re-deriving that offset separately,
  // which is fragile when Q * n - P is non-positive.
  int i = 0;
  if (n != NB) {
    for (; i < (Q * n) - P; i += P) {
      sB[i + tx] = dB[i + tx];
    }
  } else {
#pragma unroll
    for (; i < (Q * NB) - P; i += P) {
      sB[i + tx] = dB[i + tx];
    }
  }

  // cleanup for B: at most P elements remain, starting at offset i
  // (for n <= 0 the bound is non-positive and nothing is written).
  if (tx < (Q * n) - i) {
    sB[i + tx] = dB[i + tx];
  }
}
78 
////////////////////////////////////////////////////////////////////////////////
// Write C from registers to global memory.
// C is (P x NB), stored column-major with leading dimension lddc.
// Thread tx owns row tx and stores rC[j] to dC[j * lddc + tx]:
//   - n == NB: all NB columns are written on an unrolled, unguarded path.
//   - otherwise: only columns j < n are written.
// Intended for a 1D thread config. with (P x 1) threads; tx is not bounds-checked.
// No sync at the end of the function.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void write_C_r2g_1D_nosync(const int tx, const int n, T rC[NB], T *dC, int lddc) {
  // Start of this thread's row; successive columns are lddc elements apart.
  T *row = dC + tx;

  if (n == NB) {
    // Full tile: every column is valid.
#pragma unroll
    for (int col = 0; col < NB; col++) {
      row[col * lddc] = rC[col];
    }
    return;
  }

  // Partial tile: only the leading n columns are valid.
#pragma unroll
  for (int col = 0; col < NB; col++) {
    if (col < n) {
      row[col * lddc] = rC[col];
    }
  }
}
100 
////////////////////////////////////////////////////////////////////////////////
// Multiply C = A x B using 1D threads in a (P x 1) config.
// A (P x Q) in registers: rA[0 .. Q) holds this thread's row of A.
// B (Q x NB) in shared memory: column i starts at sB[i * sldb].
// C in registers: rC[i] = sum_{k < Q} rA[k] * sB[i * sldb + k]. Previous contents of
// rC are overwritten.
// sB is read without synchronization, so any preceding writes to it must already
// be visible (e.g. via __syncthreads() in the caller). tx is not used.
// No sync at the end of the function.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void mul_rAsBrC_1D_nosync(const int tx, T rA[Q], T *sB, int sldb, T rC[NB]) {
#pragma unroll
  for (int col = 0; col < NB; col++) {
    const T *b_col = sB + col * sldb;
    // Dot product of this thread's row of A with column col of B, summed in k order.
    T acc = 0.0;
#pragma unroll
    for (int k = 0; k < Q; k++) {
      acc += rA[k] * b_col[k];
    }
    rC[col] = acc;
  }
}
123 
////////////////////////////////////////////////////////////////////////////////
// Multiply-accumulate C += A x B using 1D threads in a (P x 1) config.
// A (P x Q) in registers: rA[0 .. Q) holds this thread's row of A.
// B (Q x NB) in shared memory: column i starts at sB[i * sldb].
// C in registers: rC[i] += sum_{k < Q} rA[k] * sB[i * sldb + k], with the products
// added to the existing rC[i] one at a time in k order.
// sB is read without synchronization, so any preceding writes to it must already
// be visible (e.g. via __syncthreads() in the caller). tx is not used.
// No sync at the end of the function.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void addmul_rAsBrC_1D_nosync(const int tx, T rA[Q], T *sB, int sldb, T rC[NB]) {
#pragma unroll
  for (int col = 0; col < NB; col++) {
    const T *b_col = sB + col * sldb;
    // Seed with the current value so the update accumulates onto C.
    T acc = rC[col];
#pragma unroll
    for (int k = 0; k < Q; k++) {
      acc += rA[k] * b_col[k];
    }
    rC[col] = acc;
  }
}
145 
146 #endif  // CEED_MAGMA_COMMON_NONTENSOR_H
147