xref: /libCEED/include/ceed/jit-source/magma/magma-common-nontensor.h (revision d4cc18453651bd0f94c1a2e078b2646a92dafdcc)
1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 /// @file
9 /// Internal header for MAGMA backend common non-tensor basis definitions
10 #pragma once
11 
12 #include "magma-common-defs.h"
13 
////////////////////////////////////////////////////////////////////////////////
// Load the (P x Q) matrix A from global memory into per-thread registers
// (no-transpose variant), staging it through shared memory.
//
// Expected launch shape: a 2D block of P x BY threads (tx in [0, P), ty in [0, BY)).
// All P * BY threads cooperatively copy the P * Q entries of dA into sA, so sA
// must hold at least P * Q elements. After the internal barrier, thread tx
// gathers rA[j] = sA[j * P + tx] for j in [0, Q): row tx of A when A is stored
// column-major (entry (r, c) at offset r + c * P).
//
// Contains a __syncthreads(), so every thread of the block must call it.
// No barrier on exit: synchronize before sA is overwritten again.
template <typename T, int P, int Q, int BY>
static __device__ __inline__ void read_A_notrans_g2r_1D_nosync(const int tx, const int ty, const T *dA, T *sA, T rA[Q]) {
  const int n_threads = P * BY;  // threads cooperating on the copy
  const int n_elems   = P * Q;   // entries of A
  const int lane      = ty * P + tx;
  int       base;

  // Whole strides: lane < n_threads, so every index written here is < n_elems.
#pragma unroll
  for (base = 0; base < n_elems - n_threads; base += n_threads) {
    sA[base + lane] = dA[base + lane];
  }
  // The last (possibly partial) stride needs a bounds check.
  if (base + lane < n_elems) {
    sA[base + lane] = dA[base + lane];
  }
  __syncthreads();

  // Strided gather of this thread's row (stride P between columns).
  const T *row = sA + tx;
#pragma unroll
  for (int col = 0; col < Q; col++) {
    rA[col] = row[col * P];
  }
}

////////////////////////////////////////////////////////////////////////////////
// Load the (P x Q) matrix A from global memory into per-thread registers
// (transpose variant), staging it through shared memory.
//
// Expected launch shape: a 2D block of P x BY threads (tx in [0, P), ty in [0, BY)).
// All P * BY threads cooperatively copy the P * Q entries of dA into sA, so sA
// must hold at least P * Q elements. After the internal barrier, thread tx
// gathers the Q consecutive entries sA[tx * Q + j], j in [0, Q): row tx of A
// when A is stored row-major (entry (r, c) at offset r * Q + c).
//
// Contains a __syncthreads(), so every thread of the block must call it.
// No barrier on exit: synchronize before sA is overwritten again.
template <typename T, int P, int Q, int BY>
static __device__ __inline__ void read_A_trans_g2r_1D_nosync(const int tx, const int ty, const T *dA, T *sA, T rA[Q]) {
  const int n_threads = P * BY;  // threads cooperating on the copy
  const int n_elems   = P * Q;   // entries of A
  const int lane      = ty * P + tx;
  int       base;

  // Whole strides: lane < n_threads, so every index written here is < n_elems.
#pragma unroll
  for (base = 0; base < n_elems - n_threads; base += n_threads) {
    sA[base + lane] = dA[base + lane];
  }
  // The last (possibly partial) stride needs a bounds check.
  if (base + lane < n_elems) {
    sA[base + lane] = dA[base + lane];
  }
  __syncthreads();

  // Contiguous gather of this thread's row.
  const T *row = sA + tx * Q;
#pragma unroll
  for (int col = 0; col < Q; col++) {
    rA[col] = row[col];
  }
}

////////////////////////////////////////////////////////////////////////////////
// Copy the leading n columns of B from global memory to shared memory.
//
// B is (Q x NB) with Q contiguous entries per column, so the first n columns
// are the first Q * n contiguous entries of dB; those are copied to sB.
// Expected launch shape: P x 1 threads, tx in [0, P), striding by P.
// When n == NB the loop trip count is a compile-time constant and is unrolled;
// any other n uses a runtime-bounded loop.
// No barrier is issued: synchronize before other threads read sB.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void read_B_g2s_1D_nosync(const int tx, const int n, const T *dB, T *sB) {
  const int n_elems = Q * n;
  int       base;

  if (n == NB) {
#pragma unroll
    for (base = 0; base < Q * NB - P; base += P) {
      sB[base + tx] = dB[base + tx];
    }
  } else {
    for (base = 0; base < n_elems - P; base += P) {
      sB[base + tx] = dB[base + tx];
    }
  }

  // The final (possibly partial) stride needs a bounds check.
  if (base + tx < n_elems) {
    sB[base + tx] = dB[base + tx];
  }
}

////////////////////////////////////////////////////////////////////////////////
// Store this thread's row of C from registers to global memory.
//
// C is (P x NB), stored column-major (entry (r, c) at offset r + c * P).
// Expected launch shape: P x 1 threads; thread tx writes row tx.
// Only the first n columns (rC[0 .. n-1]) are written; n == NB uses an
// unrolled loop. The template parameter Q is not referenced here.
// No barrier is issued.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void write_C_r2g_1D_nosync(const int tx, const int n, T rC[NB], T *dC) {
  T *row = dC + tx;  // consecutive columns of this row are P apart

  if (n == NB) {
#pragma unroll
    for (int col = 0; col < NB; col++) {
      row[col * P] = rC[col];
    }
  } else {
    for (int col = 0; col < n; col++) {
      row[col * P] = rC[col];
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// Accumulate this thread's row of C from registers into global memory.
//
// C is (P x NB), stored column-major (entry (r, c) at offset r + c * P).
// Expected launch shape: P x 1 threads; thread tx updates row tx with a plain
// (non-atomic) +=. Only the first n columns (rC[0 .. n-1]) are updated;
// n == NB uses an unrolled loop. The template parameter Q is not referenced here.
// No barrier is issued.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void sum_C_r2g_1D_nosync(const int tx, const int n, T rC[NB], T *dC) {
  T *row = dC + tx;  // consecutive columns of this row are P apart

  if (n == NB) {
#pragma unroll
    for (int col = 0; col < NB; col++) {
      row[col * P] += rC[col];
    }
  } else {
    for (int col = 0; col < n; col++) {
      row[col * P] += rC[col];
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// Per-thread product C = A x B, one row of A (and of C) per thread.
//
// rA : this thread's row of A (Q entries, in registers).
// sB : B (Q x NB), expected in shared memory, column-major (column c starts
//      at sB + c * Q).
// rC : overwritten with rC[c] = sum_k rA[k] * B(k, c) for every c in [0, NB).
// All NB columns are computed unconditionally.
// Intended for the P x 1 thread layout; no barrier is issued.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void mul_rAsBrC_1D_nosync(T rA[Q], T *sB, T rC[NB]) {
  T b_col[Q];

#pragma unroll
  for (int col = 0; col < NB; col++) {
    // Pull column `col` of B into registers before the dot product.
    const T *src = sB + col * Q;
#pragma unroll
    for (int k = 0; k < Q; k++) {
      b_col[k] = src[k];
    }

    T acc = 0.0;
#pragma unroll
    for (int k = 0; k < Q; k++) {
      acc += rA[k] * b_col[k];
    }
    rC[col] = acc;
  }
}

////////////////////////////////////////////////////////////////////////////////
// Per-thread accumulation C += A x B, one row of A (and of C) per thread.
//
// rA : this thread's row of A (Q entries, in registers).
// sB : B (Q x NB), expected in shared memory, column-major (column c starts
//      at sB + c * Q).
// rC : updated in place, rC[c] += sum_k rA[k] * B(k, c) for every c in [0, NB).
// All NB columns are updated unconditionally.
// Intended for the P x 1 thread layout; no barrier is issued.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void addmul_rAsBrC_1D_nosync(T rA[Q], T *sB, T rC[NB]) {
  T b_col[Q];

#pragma unroll
  for (int col = 0; col < NB; col++) {
    // Pull column `col` of B into registers before the dot product.
    const T *src = sB + col * Q;
#pragma unroll
    for (int k = 0; k < Q; k++) {
      b_col[k] = src[k];
    }

    // Start from the existing value so the summation order matches rC[col] += ...
    T acc = rC[col];
#pragma unroll
    for (int k = 0; k < Q; k++) {
      acc += rA[k] * b_col[k];
    }
    rC[col] = acc;
  }
}
