1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED: http://github.com/ceed
7
8 /// @file
9 /// Internal header for MAGMA backend common non-tensor basis definitions
10 #pragma once
11
12 #include "magma-common-defs.h"
13
////////////////////////////////////////////////////////////////////////////////
// Read A (no-trans) from global memory into registers, staged through sA.
// A is (P x Q): all P * Q entries are first copied verbatim from dA to sA,
// then thread tx gathers rA[j] = sA[j * P + tx] for j = 0, ..., Q - 1
// 2D thread config. with (P x BY) threads, flattened as ty * P + tx
// A barrier separates filling sA from reading it back, but there is
// no sync at the end of the function
template <typename T, int P, int Q, int BY>
static __device__ __inline__ void read_A_notrans_g2r_1D_nosync(const int tx, const int ty, const T *dA, T *sA, T rA[Q]) {
  const int num_entries = P * Q;   // total number of entries in A
  const int pass_size   = P * BY;  // entries moved by one pass of all threads
  const int thread_id   = ty * P + tx;
  int       offset;

  // Full passes: every thread moves exactly one entry per pass
#pragma unroll
  for (offset = 0; offset < num_entries - pass_size; offset += pass_size) {
    sA[offset + thread_id] = dA[offset + thread_id];
  }

  // Last pass may be partial, so guard against running past the end of A
  const int last = offset + thread_id;
  if (last < num_entries) {
    sA[last] = dA[last];
  }

  // All of A must be in sA before any thread gathers its entries
  __syncthreads();

  // Gather Q entries with stride P, starting at this thread's tx
  const T *my_A = sA + tx;
#pragma unroll
  for (int k = 0; k < Q; k++) {
    rA[k] = my_A[k * P];
  }
}
38
////////////////////////////////////////////////////////////////////////////////
// Read A (trans) from global memory into registers, staged through sA.
// A is (P x Q): all P * Q entries are first copied verbatim from dA to sA,
// then thread tx gathers rA[j] = sA[tx * Q + j] for j = 0, ..., Q - 1
// 2D thread config. with (P x BY) threads, flattened as ty * P + tx
// A barrier separates filling sA from reading it back, but there is
// no sync at the end of the function
template <typename T, int P, int Q, int BY>
static __device__ __inline__ void read_A_trans_g2r_1D_nosync(const int tx, const int ty, const T *dA, T *sA, T rA[Q]) {
  const int num_entries = P * Q;   // total number of entries in A
  const int pass_size   = P * BY;  // entries moved by one pass of all threads
  const int thread_id   = ty * P + tx;
  int       offset;

  // Full passes: every thread moves exactly one entry per pass
#pragma unroll
  for (offset = 0; offset < num_entries - pass_size; offset += pass_size) {
    sA[offset + thread_id] = dA[offset + thread_id];
  }

  // Last pass may be partial, so guard against running past the end of A
  const int last = offset + thread_id;
  if (last < num_entries) {
    sA[last] = dA[last];
  }

  // All of A must be in sA before any thread gathers its entries
  __syncthreads();

  // Gather Q contiguous entries starting at offset tx * Q
  const T *my_A = sA + tx * Q;
#pragma unroll
  for (int k = 0; k < Q; k++) {
    rA[k] = my_A[k];
  }
}
63
////////////////////////////////////////////////////////////////////////////////
// Read B from global to shared
// B is (Q x NB); the leading Q * n entries are copied verbatim from dB to sB
// 1D thread config. with (P x 1) threads, each moving one entry per pass
// n == NB takes a path whose trip count is known at compile time
// no sync at the end of the function
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void read_B_g2s_1D_nosync(const int tx, const int n, const T *dB, T *sB) {
  int offset;

  if (n == NB) {
    // Full block: bounds depend only on template parameters, so unroll
#pragma unroll
    for (offset = 0; offset < Q * NB - P; offset += P) {
      sB[offset + tx] = dB[offset + tx];
    }
  } else {
    // Partial block: bound depends on the runtime value of n
    for (offset = 0; offset < Q * n - P; offset += P) {
      sB[offset + tx] = dB[offset + tx];
    }
  }

  // Last pass may be partial, so guard against running past Q * n entries
  const int last = offset + tx;
  if (last < Q * n) {
    sB[last] = dB[last];
  }
}
87
////////////////////////////////////////////////////////////////////////////////
// Write C from registers to global
// C is (P x NB); thread tx stores rC[i] to dC[i * P + tx] for i = 0, ..., n - 1
// 1D thread config. with (P x 1) threads
// n == NB takes a path whose trip count is known at compile time
// no sync at the end of the function
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void write_C_r2g_1D_nosync(const int tx, const int n, T rC[NB], T *dC) {
  T *my_C = dC + tx;  // first entry owned by this thread; later ones are P apart

  if (n == NB) {
    // Full block
#pragma unroll
    for (int col = 0; col < NB; col++) {
      my_C[col * P] = rC[col];
    }
  } else {
    // Partial block: only the first n values of rC are stored
    for (int col = 0; col < n; col++) {
      my_C[col * P] = rC[col];
    }
  }
}
106
////////////////////////////////////////////////////////////////////////////////
// Sum into C from registers to global
// C is (P x NB); thread tx adds rC[i] into dC[i * P + tx] for i = 0, ..., n - 1
// 1D thread config. with (P x 1) threads
// n == NB takes a path whose trip count is known at compile time
// The update is a plain (non-atomic) read-modify-write
// no sync at the end of the function
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void sum_C_r2g_1D_nosync(const int tx, const int n, T rC[NB], T *dC) {
  T *my_C = dC + tx;  // first entry owned by this thread; later ones are P apart

  if (n == NB) {
    // Full block
#pragma unroll
    for (int col = 0; col < NB; col++) {
      my_C[col * P] += rC[col];
    }
  } else {
    // Partial block: only the first n values of rC are accumulated
    for (int col = 0; col < n; col++) {
      my_C[col * P] += rC[col];
    }
  }
}
125
////////////////////////////////////////////////////////////////////////////////
// Multiply C = A x B using 1D threads in P x 1 config
// A (P x Q) in registers, one row (rA, Q values) per thread
// B (Q x NB) in shared memory, column i stored at sB[i * Q], ..., sB[i * Q + Q - 1]
// C in registers, one row (rC, NB values) per thread; previous contents are overwritten
// no sync at the end of the function
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void mul_rAsBrC_1D_nosync(T rA[Q], T *sB, T rC[NB]) {
  T b_col[Q];  // local copy of the current column of B

#pragma unroll
  for (int col = 0; col < NB; col++) {
    const T *sB_col = sB + col * Q;

    // Stage the whole column locally before it is used
#pragma unroll
    for (int k = 0; k < Q; k++) {
      b_col[k] = sB_col[k];
    }

    // Dot product of this thread's row of A with the staged column
    rC[col] = 0.0;
#pragma unroll
    for (int k = 0; k < Q; k++) {
      rC[col] += rA[k] * b_col[k];
    }
  }
}
149
////////////////////////////////////////////////////////////////////////////////
// Multiply C += A x B using 1D threads in P x 1 config
// A (P x Q) in registers, one row (rA, Q values) per thread
// B (Q x NB) in shared memory, column i stored at sB[i * Q], ..., sB[i * Q + Q - 1]
// C in registers, one row (rC, NB values) per thread; products are added to the existing values
// no sync at the end of the function
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void addmul_rAsBrC_1D_nosync(T rA[Q], T *sB, T rC[NB]) {
  T b_col[Q];  // local copy of the current column of B

#pragma unroll
  for (int col = 0; col < NB; col++) {
    const T *sB_col = sB + col * Q;

    // Stage the whole column locally before it is used
#pragma unroll
    for (int k = 0; k < Q; k++) {
      b_col[k] = sB_col[k];
    }

    // Accumulate the dot product on top of whatever rC[col] already holds
#pragma unroll
    for (int k = 0; k < Q; k++) {
      rC[col] += rA[k] * b_col[k];
    }
  }
}
172