xref: /libCEED/rust/libceed-sys/c-src/include/ceed/jit-source/magma/magma-common-nontensor.h (revision 9d15e85b4f78ffb2d2860753c87a3b1789cc3bb6)
1 // Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 /// @file
9 /// Internal header for MAGMA backend common non-tensor basis definitions
10 #ifndef CEED_MAGMA_COMMON_NONTENSOR_H
11 #define CEED_MAGMA_COMMON_NONTENSOR_H
12 
13 #include "magma-common-defs.h"
14 
////////////////////////////////////////////////////////////////////////////////
// read A (no-trans) from global to reg.
// A is (P x Q), indexed with leading dimension P: A(row, col) = dA[col * P + row].
// 1D thread config. with (P x 1) threads: thread `tx` loads row `tx` of A,
// i.e. rA[col] = A(tx, col) for col = 0 .. Q-1.
// For a fixed column, consecutive threads read consecutive addresses.
// NB is not used by this helper (presumably kept for a uniform template
// signature across the helpers in this header).
// no sync at the end of the function
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void read_A_notrans_g2r_1D_nosync(const int tx, const T *dA, T rA[Q]) {
  const T *row = dA + tx;  // start of this thread's row; columns are P apart

#pragma unroll
  for (int col = 0; col < Q; ++col) {
    rA[col] = row[col * P];
  }
}
26 }
27 
////////////////////////////////////////////////////////////////////////////////
// read A (trans) from global to reg., staging through the buffer sA
// A is (P x Q)
// 2D thread config. with (P x BY) threads:
//   1. All threads cooperatively copy the P*Q contiguous entries of dA into sA,
//      using the flat id tid = ty * blockDim.x + tx and a stride of P*BY.
//      This assumes blockDim.x * blockDim.y == P * BY (tid < P*BY).
//   2. After a block-wide barrier, thread `tx` copies the Q contiguous entries
//      sA[tx*Q .. tx*Q + Q-1] into rA. NOTE(review): read as a (Q x P) matrix
//      with leading dimension Q, this is column tx, i.e. row tx of its (P x Q)
//      transpose -- confirm against callers.
// sA must be visible to every thread of the block and hold at least P*Q
// entries (presumably __shared__ memory).
// Calls __syncthreads(), so every thread of the block must reach this call.
// no sync at the end of the function
template <typename T, int P, int Q, int BY>
static __device__ __inline__ void read_A_trans_g2r_1D_nosync(const int tx, const int ty, const T *dA, T *sA, T rA[Q]) {
  const int total  = P * Q;   // entries to stage
  const int stride = P * BY;  // threads taking part in the copy
  const int tid    = ty * blockDim.x + tx;

  // Whole strides: with tid < stride, every index here is < total, so no guard.
  int base = 0;
#pragma unroll
  for (; base < total - stride; base += stride) {
    sA[base + tid] = dA[base + tid];
  }
  // Last (possibly partial) stride.
  if (base + tid < total) {
    sA[base + tid] = dA[base + tid];
  }
  __syncthreads();  // make every thread's copy visible before the reads below

  const T *src = sA + tx * Q;
#pragma unroll
  for (int k = 0; k < Q; ++k) {
    rA[k] = src[k];
  }
}
51 }
52 
////////////////////////////////////////////////////////////////////////////////
// read B from global to shared
// B is (Q x NB); only the first Q * n contiguous entries of dB are copied,
// so a partial batch (n < NB) is supported.
// 1D thread config. with (P x 1) threads: thread tx copies entries
// tx, tx + P, tx + 2P, ... that are below Q * n.
// When n == NB the loop bound is a compile-time constant and the loop is
// marked #pragma unroll.
// no sync at the end of the function: a barrier is still required before sB
// is read by threads other than the one that wrote each entry
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void read_B_g2s_1D_nosync(const int tx, const int n, const T *dB, T *sB) {
  const int total = Q * n;  // entries to copy
  int       base  = 0;

  // Whole strides of P entries (no per-thread guard needed).
  if (n == NB) {
#pragma unroll
    for (; base < Q * NB - P; base += P) {
      sB[base + tx] = dB[base + tx];
    }
  } else {
    for (; base < total - P; base += P) {
      sB[base + tx] = dB[base + tx];
    }
  }

  // Last (possibly partial) stride.
  if (base + tx < total) {
    sB[base + tx] = dB[base + tx];
  }
}
75 }
76 
////////////////////////////////////////////////////////////////////////////////
// write C from reg. to global
// C is (P x NB), indexed with leading dimension P: C(row, col) = dC[col * P + row].
// Only the first n columns are written (n < NB for a partial batch).
// NOTE(review): n > NB would read past rC; presumably callers keep n <= NB.
// 1D thread config. with (P x 1) threads: thread tx writes row tx,
// i.e. dC[col * P + tx] = rC[col] for col = 0 .. n-1.
// Q is not used by this helper.
// no sync at the end of the function
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void write_C_r2g_1D_nosync(const int tx, const int n, T rC[NB], T *dC) {
  T *row = dC + tx;  // this thread's row; columns are P apart

  if (n == NB) {
    // Full batch: compile-time trip count.
#pragma unroll
    for (int col = 0; col < NB; ++col) {
      row[col * P] = rC[col];
    }
  } else {
    for (int col = 0; col < n; ++col) {
      row[col * P] = rC[col];
    }
  }
}
94 }
95 
////////////////////////////////////////////////////////////////////////////////
// multiply C = A x B using 1D threads in P x 1 config
// A (P x Q)  in reg., one row per thread
// B (Q x NB) in shared memory; column c occupies sB[c*Q .. c*Q + Q-1]
// C in registers -- one row per thread:
//   rC[c] = sum_{k=0}^{Q-1} rA[k] * sB[c*Q + k], for c = 0 .. NB-1
// Previous contents of rC are discarded.
// P is not used by this helper.
// no sync at the end of the function; all entries of sB read here must
// already be written and visible to this thread
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void mul_rAsBrC_1D_nosync(T rA[Q], T *sB, T rC[NB]) {
  T rB[Q];  // register copy of the current column of B

#pragma unroll
  for (int c = 0; c < NB; ++c) {
    const T *colB = sB + c * Q;
#pragma unroll
    for (int k = 0; k < Q; ++k) {
      rB[k] = colB[k];
    }

    T acc = 0.0;
#pragma unroll
    for (int k = 0; k < Q; ++k) {
      acc += rA[k] * rB[k];
    }
    rC[c] = acc;
  }
}
118 }
119 
////////////////////////////////////////////////////////////////////////////////
// multiply C += A x B using 1D threads in P x 1 config
// A (P x Q)  in reg., one row per thread
// B (Q x NB) in shared memory; column c occupies sB[c*Q .. c*Q + Q-1]
// C in registers -- one row per thread:
//   rC[c] += sum_{k=0}^{Q-1} rA[k] * sB[c*Q + k], for c = 0 .. NB-1
// rC must hold meaningful values on entry (it is accumulated into).
// P is not used by this helper.
// no sync at the end of the function; all entries of sB read here must
// already be written and visible to this thread
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void addmul_rAsBrC_1D_nosync(T rA[Q], T *sB, T rC[NB]) {
  T rB[Q];  // register copy of the current column of B

#pragma unroll
  for (int c = 0; c < NB; ++c) {
    const T *colB = sB + c * Q;
#pragma unroll
    for (int k = 0; k < Q; ++k) {
      rB[k] = colB[k];
    }

    T acc = rC[c];
#pragma unroll
    for (int k = 0; k < Q; ++k) {
      acc += rA[k] * rB[k];
    }
    rC[c] = acc;
  }
}
141 }
142 
143 #endif  // CEED_MAGMA_COMMON_NONTENSOR_H
144