xref: /libCEED/include/ceed/jit-source/magma/magma-common-nontensor.h (revision c2cc34eeef4dafbcd1d00ef770c45974c5ea3da2)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 /// @file
9 /// Internal header for MAGMA backend common non-tensor basis definitions
10 #pragma once
11 
12 #include "magma-common-defs.h"
13 
////////////////////////////////////////////////////////////////////////////////
// Load A (no-transpose access) from global memory into per-thread registers,
// staging it through the buffer sA.
//   1. The P * Q entries of dA are copied into sA cooperatively: the thread with
//      linear id tid = ty * P + tx handles entries tid, tid + P * BY,
//      tid + 2 * P * BY, ... (the last pass is bounds-checked)
//   2. After a block-wide barrier, thread tx gathers rA[j] = sA[j * P + tx] for
//      j = 0, ..., Q - 1 (i.e. row tx of A if A is stored column-major)
// A is (P x Q)
// 2D thread config. with (P x BY) threads
// The function contains a __syncthreads() between steps 1 and 2, so it has to be
// reached by every thread of the block; there is no sync at the end of the function
template <typename T, int P, int Q, int BY>
static __device__ __inline__ void read_A_notrans_g2r_1D_nosync(const int tx, const int ty, const T *dA, T *sA, T rA[Q]) {
  const int total = P * Q;         // number of entries of A
  const int step  = P * BY;        // entries copied per pass
  const int tid   = ty * P + tx;   // linear thread id
  int       base;

  // full passes: every thread copies exactly one entry
#pragma unroll
  for (base = 0; base + step < total; base += step) {
    sA[base + tid] = dA[base + tid];
  }
  // last pass: may be partial, so guard against running past the end of A
  if (base + tid < total) {
    sA[base + tid] = dA[base + tid];
  }
  // make the entries written by the other threads visible before gathering
  __syncthreads();

#pragma unroll
  for (int col = 0; col < Q; col++) {
    rA[col] = sA[col * P + tx];
  }
}
38 
////////////////////////////////////////////////////////////////////////////////
// Load A (transpose access) from global memory into per-thread registers,
// staging it through the buffer sA.
//   1. The P * Q entries of dA are copied into sA cooperatively: the thread with
//      linear id tid = ty * P + tx handles entries tid, tid + P * BY,
//      tid + 2 * P * BY, ... (the last pass is bounds-checked)
//   2. After a block-wide barrier, thread tx gathers the Q consecutive entries
//      rA[j] = sA[tx * Q + j] for j = 0, ..., Q - 1
// A is (P x Q)
// 2D thread config. with (P x BY) threads
// The function contains a __syncthreads() between steps 1 and 2, so it has to be
// reached by every thread of the block; there is no sync at the end of the function
template <typename T, int P, int Q, int BY>
static __device__ __inline__ void read_A_trans_g2r_1D_nosync(const int tx, const int ty, const T *dA, T *sA, T rA[Q]) {
  const int total = P * Q;         // number of entries of A
  const int step  = P * BY;        // entries copied per pass
  const int tid   = ty * P + tx;   // linear thread id
  int       base;

  // full passes: every thread copies exactly one entry
#pragma unroll
  for (base = 0; base + step < total; base += step) {
    sA[base + tid] = dA[base + tid];
  }
  // last pass: may be partial, so guard against running past the end of A
  if (base + tid < total) {
    sA[base + tid] = dA[base + tid];
  }
  // make the entries written by the other threads visible before gathering
  __syncthreads();

  // this thread's Q entries are contiguous in sA
  const T *sA_tx = sA + tx * Q;
#pragma unroll
  for (int k = 0; k < Q; k++) {
    rA[k] = sA_tx[k];
  }
}
63 
////////////////////////////////////////////////////////////////////////////////
// Copy B from global memory (dB) to shared memory (sB)
// B is (Q x NB); only the first Q * n entries are copied, where n is the number
// of columns actually in use (the n == NB case gets an unrolled loop with a
// compile-time trip count)
// 1D thread config. with (P x 1) threads: thread tx copies entries
// tx, tx + P, tx + 2 * P, ... (the last pass is bounds-checked)
// no sync at the end of the function
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void read_B_g2s_1D_nosync(const int tx, const int n, const T *dB, T *sB) {
  int offset;

  if (n == NB) {
    // full block: bound is known at compile time
#pragma unroll
    for (offset = 0; offset + P < Q * NB; offset += P) {
      sB[offset + tx] = dB[offset + tx];
    }
  } else {
    // partial block: bound depends on the runtime column count
    for (offset = 0; offset + P < Q * n; offset += P) {
      sB[offset + tx] = dB[offset + tx];
    }
  }
  // last pass: may be partial, so guard against running past the Q * n entries
  if (offset + tx < Q * n) {
    sB[offset + tx] = dB[offset + tx];
  }
}
87 
////////////////////////////////////////////////////////////////////////////////
// Store C from registers (rC) to global memory (dC)
// C is (P x NB); thread tx writes rC[i] to dC[i * P + tx] for the first n
// columns (the n == NB case gets an unrolled loop with a compile-time trip count)
// 1D thread config. with (P x 1) threads
// no sync at the end of the function
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void write_C_r2g_1D_nosync(const int tx, const int n, T rC[NB], T *dC) {
  // first entry written by this thread; consecutive columns are P entries apart
  T *dC_tx = dC + tx;

  if (n == NB) {
    // full block: bound is known at compile time
#pragma unroll
    for (int col = 0; col < NB; col++) {
      dC_tx[col * P] = rC[col];
    }
  } else {
    // partial block: bound depends on the runtime column count
    for (int col = 0; col < n; col++) {
      dC_tx[col * P] = rC[col];
    }
  }
}
106 
////////////////////////////////////////////////////////////////////////////////
// Compute C = A x B using 1D threads in P x 1 config
// A (P x Q)  in reg., one row per thread (rA holds this thread's Q entries)
// B (Q x NB) in shared memory, column i stored at sB[i * Q], ..., sB[i * Q + Q - 1]
// C in registers -- one row per thread; rC[i] is overwritten with
//   sum_j rA[j] * sB[i * Q + j]
// no sync at the end of the function
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void mul_rAsBrC_1D_nosync(T rA[Q], T *sB, T rC[NB]) {
  T rB[Q];   // private copy of the current column of B

#pragma unroll
  for (int col = 0; col < NB; col++) {
    const T *sB_col = sB + col * Q;

    // stage the whole column before using it
#pragma unroll
    for (int k = 0; k < Q; k++) {
      rB[k] = sB_col[k];
    }

    // start from zero, then accumulate the dot product in index order
    rC[col] = 0.0;
#pragma unroll
    for (int k = 0; k < Q; k++) {
      rC[col] += rA[k] * rB[k];
    }
  }
}
130 
////////////////////////////////////////////////////////////////////////////////
// Compute C += A x B using 1D threads in P x 1 config
// A (P x Q)  in reg., one row per thread (rA holds this thread's Q entries)
// B (Q x NB) in shared memory, column i stored at sB[i * Q], ..., sB[i * Q + Q - 1]
// C in registers -- one row per thread; rC[i] is incremented by
//   sum_j rA[j] * sB[i * Q + j], so the caller must have initialized rC
// no sync at the end of the function
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void addmul_rAsBrC_1D_nosync(T rA[Q], T *sB, T rC[NB]) {
  T rB[Q];   // private copy of the current column of B

#pragma unroll
  for (int col = 0; col < NB; col++) {
    const T *sB_col = sB + col * Q;

    // stage the whole column before using it
#pragma unroll
    for (int k = 0; k < Q; k++) {
      rB[k] = sB_col[k];
    }

    // accumulate on top of the existing value, in index order
#pragma unroll
    for (int k = 0; k < Q; k++) {
      rC[col] += rA[k] * rB[k];
    }
  }
}
153