xref: /libCEED/include/ceed/jit-source/magma/magma-common-nontensor.h (revision 5aed82e4fa97acf4ba24a7f10a35f5303a6798e0)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 /// @file
9 /// Internal header for MAGMA backend common non-tensor basis definitions
10 #ifndef CEED_MAGMA_COMMON_NONTENSOR_H
11 #define CEED_MAGMA_COMMON_NONTENSOR_H
12 
13 #include "magma-common-defs.h"
14 
////////////////////////////////////////////////////////////////////////////////
// Load A (no-trans) from global memory into per-thread registers, staging it
// through the block-visible buffer sA (assumed to be shared memory).
//
// Template parameters:
//   T  : element type
//   P  : threads along x; thread tx receives row tx of A
//   Q  : entries loaded into rA per thread
//   BY : threads along y (expected launch shape: P x BY)
//
// Step 1: all P * BY threads cooperatively copy the P * Q entries of dA into sA.
//         This uses a fixed-trip unrolled loop plus one bounds-checked tail pass.
// Step 2: __syncthreads().
// Step 3: each thread reads rA[j] = sA[j * P + tx], i.e. row tx of A viewed as
//         a column-major P x Q matrix (leading dimension P).
//
// The function contains a __syncthreads(), so every thread in the block must call it.
// There is no barrier after step 3 (hence "_nosync"). The caller must
// synchronize before anything overwrites sA.
template <typename T, int P, int Q, int BY>
static __device__ __inline__ void read_A_notrans_g2r_1D_nosync(const int tx, const int ty, const T *dA, T *sA, T rA[Q]) {
  const int stride = P * BY;       // number of cooperating threads
  const int total  = P * Q;        // number of entries to stage
  const int lane   = ty * P + tx;  // linear thread id within the block
  int       base   = 0;

  // Full passes: stop one pass early so the last, possibly partial, pass
  // can be bounds-checked separately below.
#pragma unroll
  for (; base < total - stride; base += stride) {
    sA[base + lane] = dA[base + lane];
  }
  if (base + lane < total) {
    sA[base + lane] = dA[base + lane];
  }
  __syncthreads();

  // Gather row tx of A (leading dimension P) into registers.
#pragma unroll
  for (int col = 0; col < Q; col++) {
    rA[col] = sA[col * P + tx];
  }
}
39 
////////////////////////////////////////////////////////////////////////////////
// Load A (trans) from global memory into per-thread registers, staging it
// through the block-visible buffer sA (assumed to be shared memory).
//
// Template parameters:
//   T  : element type
//   P  : threads along x; thread tx receives entries tx * Q .. tx * Q + Q - 1
//   Q  : entries loaded into rA per thread
//   BY : threads along y (expected launch shape: P x BY)
//
// Step 1: all P * BY threads cooperatively copy the P * Q entries of dA into sA.
//         This uses a fixed-trip unrolled loop plus one bounds-checked tail pass.
// Step 2: __syncthreads().
// Step 3: each thread reads rA[j] = sA[tx * Q + j]. This is column tx of the data
//         viewed as a column-major Q x P matrix (leading dimension Q), i.e.
//         row tx of its P x Q transpose.
//
// The function contains a __syncthreads(), so every thread in the block must call it.
// There is no barrier after step 3 (hence "_nosync"). The caller must
// synchronize before anything overwrites sA.
template <typename T, int P, int Q, int BY>
static __device__ __inline__ void read_A_trans_g2r_1D_nosync(const int tx, const int ty, const T *dA, T *sA, T rA[Q]) {
  const int stride = P * BY;       // number of cooperating threads
  const int total  = P * Q;        // number of entries to stage
  const int lane   = ty * P + tx;  // linear thread id within the block
  int       base   = 0;

  // Full passes: stop one pass early so the last, possibly partial, pass
  // can be bounds-checked separately below.
#pragma unroll
  for (; base < total - stride; base += stride) {
    sA[base + lane] = dA[base + lane];
  }
  if (base + lane < total) {
    sA[base + lane] = dA[base + lane];
  }
  __syncthreads();

  // Gather the Q contiguous entries owned by thread tx (leading dimension Q).
  const T *own = sA + tx * Q;
#pragma unroll
  for (int k = 0; k < Q; k++) {
    rA[k] = own[k];
  }
}
64 
////////////////////////////////////////////////////////////////////////////////
// Copy the first n columns of B from global memory (dB) into shared memory (sB).
// B is (Q x NB), column-major with leading dimension Q, so the first n columns
// are the first Q * n contiguous entries.
//
// Expected launch: 1D, P x 1 threads. Each thread copies every P-th entry,
// starting at tx. When n == NB the trip count is a compile-time constant and
// the loop is unrolled. Otherwise a runtime-bounded loop is used. In both
// cases the final, possibly partial, pass is bounds-checked. n == 0 copies nothing.
//
// No barrier is issued (hence "_nosync"). The caller must __syncthreads()
// before other threads read sB.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void read_B_g2s_1D_nosync(const int tx, const int n, const T *dB, T *sB) {
  int offset;

  if (n == NB) {
    // Full tile: compile-time trip count.
#pragma unroll
    for (offset = 0; offset < Q * NB - P; offset += P) {
      sB[offset + tx] = dB[offset + tx];
    }
  } else {
    // Partial tile of n columns: runtime trip count.
    for (offset = 0; offset < Q * n - P; offset += P) {
      sB[offset + tx] = dB[offset + tx];
    }
  }

  // Last pass covers at most P remaining entries.
  if (offset + tx < Q * n) {
    sB[offset + tx] = dB[offset + tx];
  }
}
88 
////////////////////////////////////////////////////////////////////////////////
// Store this thread's row of C from registers (rC) to global memory (dC).
// C is (P x NB), column-major with leading dimension P. Thread tx writes
// C(tx, col) = rC[col] for the first n columns.
//
// Expected launch: 1D, P x 1 threads. When n == NB the loop has a compile-time
// trip count and is unrolled. Otherwise only n columns are written.
// The template parameter Q is not referenced here (presumably kept so all
// helpers share one template signature).
//
// No barrier is issued (hence "_nosync").
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void write_C_r2g_1D_nosync(const int tx, const int n, T rC[NB], T *dC) {
  if (n == NB) {
    // Full tile: compile-time trip count.
#pragma unroll
    for (int col = 0; col < NB; col++) {
      dC[col * P + tx] = rC[col];
    }
    return;
  }

  // Partial tile: write only the n valid columns.
  for (int col = 0; col < n; col++) {
    dC[col * P + tx] = rC[col];
  }
}
107 
////////////////////////////////////////////////////////////////////////////////
// Compute C = A x B for this thread's row, overwriting rC.
//   rA : one row of A (Q entries) held in registers by the calling thread
//   sB : B (Q x NB), column-major with leading dimension Q (in shared memory)
//   rC : the calling thread's row of C (NB entries), overwritten
// rC[col] = sum_{k < Q} rA[k] * B(k, col), accumulated in increasing k.
//
// Intended for a 1D P x 1 thread layout. The template parameter P is not
// referenced here. No barrier is issued (hence "_nosync").
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void mul_rAsBrC_1D_nosync(T rA[Q], T *sB, T rC[NB]) {
#pragma unroll
  for (int col = 0; col < NB; col++) {
    // Cache column `col` of B (Q contiguous entries) in registers.
    T bcol[Q];
#pragma unroll
    for (int k = 0; k < Q; k++) {
      bcol[k] = sB[col * Q + k];
    }

    // Start from zero: this is a plain product, not an accumulation into C.
    T sum = 0.0;
#pragma unroll
    for (int k = 0; k < Q; k++) {
      sum += rA[k] * bcol[k];
    }
    rC[col] = sum;
  }
}
131 
////////////////////////////////////////////////////////////////////////////////
// Compute C += A x B for this thread's row, accumulating into rC.
//   rA : one row of A (Q entries) held in registers by the calling thread
//   sB : B (Q x NB), column-major with leading dimension Q (in shared memory)
//   rC : the calling thread's row of C (NB entries); existing values are kept
//        and the products are added to them
// rC[col] += sum_{k < Q} rA[k] * B(k, col), accumulated in increasing k.
//
// Intended for a 1D P x 1 thread layout. The template parameter P is not
// referenced here. No barrier is issued (hence "_nosync").
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void addmul_rAsBrC_1D_nosync(T rA[Q], T *sB, T rC[NB]) {
#pragma unroll
  for (int col = 0; col < NB; col++) {
    // Cache column `col` of B (Q contiguous entries) in registers.
    T bcol[Q];
#pragma unroll
    for (int k = 0; k < Q; k++) {
      bcol[k] = sB[col * Q + k];
    }

    // Seed with the current value of C so the products are accumulated.
    T sum = rC[col];
#pragma unroll
    for (int k = 0; k < Q; k++) {
      sum += rA[k] * bcol[k];
    }
    rC[col] = sum;
  }
}
154 
155 #endif  // CEED_MAGMA_COMMON_NONTENSOR_H
156