1 // Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED: http://github.com/ceed
7
8 #include <ceed.h>
9 #include <ceed/backend.h>
10 #include <immintrin.h>
11 #include <stdbool.h>
12
// Precision-independent 4-lane SIMD abstraction: rtype is four doubles
// (__m256d, AVX) when CeedScalar is FP64, otherwise four floats (__m128, SSE).
// Both cases expose the same 4-wide lane count, so the 4-wide blocking logic
// in the kernels below does not depend on the scalar precision.
#ifdef CEED_SCALAR_IS_FP64
#define rtype __m256d
#define loadu _mm256_loadu_pd
#define storeu _mm256_storeu_pd
#define set _mm256_set_pd
#define set1 _mm256_set1_pd
// c += a * b
#ifdef __FMA__
#define fmadd(c, a, b) (c) = _mm256_fmadd_pd((a), (b), (c))
#else
// No FMA hardware: separate multiply, then vector += (GCC/Clang vector-arithmetic extension)
#define fmadd(c, a, b) (c) += _mm256_mul_pd((a), (b))
#endif
#else
#define rtype __m128
#define loadu _mm_loadu_ps
#define storeu _mm_storeu_ps
#define set _mm_set_ps
#define set1 _mm_set1_ps
// c += a * b
#ifdef __FMA__
#define fmadd(c, a, b) (c) = _mm_fmadd_ps((a), (b), (c))
#else
// No FMA hardware: separate multiply, then vector += (GCC/Clang vector-arithmetic extension)
#define fmadd(c, a, b) (c) += _mm_mul_ps((a), (b))
#endif
#endif
38
39 //------------------------------------------------------------------------------
40 // Blocked Tensor Contract
41 //------------------------------------------------------------------------------
// Contract u along its middle index with the matrix t, accumulating into v:
//   v[a, j, c] += sum_b t[j * t_stride_0 + b * t_stride_1] * u[a, b, c]
// The contiguous c index is vectorized with 4-lane SIMD loads/stores; a
// JJ x CC tile of v is kept in registers across the whole b loop.
//
// Notes:
//   - Only full CC-wide column blocks are processed (c < (C / CC) * CC);
//     leftover columns are handled by CeedTensorContract_Avx_Remainder.
//   - CC must be a multiple of 4 (the tile holds CC / 4 vectors per row).
//   - `contract` and `add` are unused here: accumulation is unconditional and
//     the caller zeroes v first when add is false.
static inline int CeedTensorContract_Avx_Blocked(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J,
                                                 const CeedScalar *restrict t, CeedTransposeMode t_mode, const CeedInt add,
                                                 const CeedScalar *restrict u, CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) {
  // Default layout: t indexed as t[j * B + b]
  CeedInt t_stride_0 = B, t_stride_1 = 1;

  if (t_mode == CEED_TRANSPOSE) {
    // Transposed layout: t indexed as t[j + b * J]
    t_stride_0 = 1;
    t_stride_1 = J;
  }

  for (CeedInt a = 0; a < A; a++) {
    // Blocks of 4 rows
    for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) {
      for (CeedInt c = 0; c < (C / CC) * CC; c += CC) {
        rtype vv[JJ][CC / 4];  // Output tile to be held in registers
        // Load the current v tile so the contraction accumulates into it
        for (CeedInt jj = 0; jj < JJ; jj++) {
          for (CeedInt cc = 0; cc < CC / 4; cc++) vv[jj][cc] = loadu(&v[(a * J + j + jj) * C + c + cc * 4]);
        }
        for (CeedInt b = 0; b < B; b++) {
          for (CeedInt jj = 0; jj < JJ; jj++) {  // unroll
            // Broadcast one t entry across all lanes for this output row
            rtype tqv = set1(t[(j + jj) * t_stride_0 + b * t_stride_1]);
            for (CeedInt cc = 0; cc < CC / 4; cc++) {  // unroll
              fmadd(vv[jj][cc], tqv, loadu(&u[(a * B + b) * C + c + cc * 4]));
            }
          }
        }
        // Write the finished tile back to v
        for (CeedInt jj = 0; jj < JJ; jj++) {
          for (CeedInt cc = 0; cc < CC / 4; cc++) storeu(&v[(a * J + j + jj) * C + c + cc * 4], vv[jj][cc]);
        }
      }
    }
    // Remainder of rows: same tiling, but only J - j (< JJ) rows are active
    const CeedInt j = (J / JJ) * JJ;

    if (j < J) {
      for (CeedInt c = 0; c < (C / CC) * CC; c += CC) {
        rtype vv[JJ][CC / 4];  // Output tile to be held in registers

        for (CeedInt jj = 0; jj < J - j; jj++) {
          for (CeedInt cc = 0; cc < CC / 4; cc++) vv[jj][cc] = loadu(&v[(a * J + j + jj) * C + c + cc * 4]);
        }
        for (CeedInt b = 0; b < B; b++) {
          for (CeedInt jj = 0; jj < J - j; jj++) {  // doesn't unroll
            rtype tqv = set1(t[(j + jj) * t_stride_0 + b * t_stride_1]);

            for (CeedInt cc = 0; cc < CC / 4; cc++) {  // unroll
              fmadd(vv[jj][cc], tqv, loadu(&u[(a * B + b) * C + c + cc * 4]));
            }
          }
        }
        for (CeedInt jj = 0; jj < J - j; jj++) {
          for (CeedInt cc = 0; cc < CC / 4; cc++) storeu(&v[(a * J + j + jj) * C + c + cc * 4], vv[jj][cc]);
        }
      }
    }
  }
  return CEED_ERROR_SUCCESS;
}
100
101 //------------------------------------------------------------------------------
102 // Serial Tensor Contract Remainder
103 //------------------------------------------------------------------------------
// Handle the leftover columns c in [(C / CC) * CC, C) that the blocked kernel
// skipped, computing the same accumulation:
//   v[a, j, c] += sum_b t[j * t_stride_0 + b * t_stride_1] * u[a, b, c]
// Columns are processed in 4-wide vector chunks; when fewer than 4 columns
// remain, u is loaded lane-by-lane with zero padding so the extra lanes of the
// fused multiply-add contribute nothing and v's neighboring values round-trip
// unchanged through the 4-wide load/store.
//
// Notes:
//   - `contract` and `add` are unused; accumulation is unconditional and the
//     caller zeroes v first when add is false.
//   - The last row block always goes through the scalar loop (see J_break),
//     which keeps the unconditional 4-wide v loads/stores from running past
//     the end of the v array on the final rows.
static inline int CeedTensorContract_Avx_Remainder(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J,
                                                   const CeedScalar *restrict t, CeedTransposeMode t_mode, const CeedInt add,
                                                   const CeedScalar *restrict u, CeedScalar *restrict v, const CeedInt JJ, const CeedInt CC) {
  // Default layout: t indexed as t[j * B + b]; transposed: t[j + b * J]
  CeedInt t_stride_0 = B, t_stride_1 = 1;

  if (t_mode == CEED_TRANSPOSE) {
    t_stride_0 = 1;
    t_stride_1 = J;
  }

  // Rows below J_break are vectorized; rows in [J_break, J) are done scalar.
  // When J divides evenly by JJ the final full block is still sent to the
  // scalar path, so the vector path never touches the tail of v.
  const CeedInt J_break = J % JJ ? (J / JJ) * JJ : (J / JJ - 1) * JJ;

  for (CeedInt a = 0; a < A; a++) {
    // Blocks of 4 columns
    for (CeedInt c = (C / CC) * CC; c < C; c += 4) {
      // Blocks of 4 rows
      for (CeedInt j = 0; j < J_break; j += JJ) {
        rtype vv[JJ];  // Output tile to be held in registers

        for (CeedInt jj = 0; jj < JJ; jj++) vv[jj] = loadu(&v[(a * J + j + jj) * C + c]);
        for (CeedInt b = 0; b < B; b++) {
          rtype tqu;

          // Fewer than 4 columns left: build the u vector with zero-padded
          // upper lanes (set takes lanes high-to-low)
          if (C - c == 1) tqu = set(0.0, 0.0, 0.0, u[(a * B + b) * C + c + 0]);
          else if (C - c == 2) tqu = set(0.0, 0.0, u[(a * B + b) * C + c + 1], u[(a * B + b) * C + c + 0]);
          else if (C - c == 3) tqu = set(0.0, u[(a * B + b) * C + c + 2], u[(a * B + b) * C + c + 1], u[(a * B + b) * C + c + 0]);
          else tqu = loadu(&u[(a * B + b) * C + c]);
          for (CeedInt jj = 0; jj < JJ; jj++) {  // unroll
            fmadd(vv[jj], tqu, set1(t[(j + jj) * t_stride_0 + b * t_stride_1]));
          }
        }
        for (CeedInt jj = 0; jj < JJ; jj++) storeu(&v[(a * J + j + jj) * C + c], vv[jj]);
      }
    }
    // Remainder of rows, all columns: plain scalar accumulation
    for (CeedInt j = J_break; j < J; j++) {
      for (CeedInt b = 0; b < B; b++) {
        const CeedScalar tq = t[j * t_stride_0 + b * t_stride_1];

        for (CeedInt c = (C / CC) * CC; c < C; c++) v[(a * J + j) * C + c] += tq * u[(a * B + b) * C + c];
      }
    }
  }
  return CEED_ERROR_SUCCESS;
}
149
150 //------------------------------------------------------------------------------
151 // Serial Tensor Contract C=1
152 //------------------------------------------------------------------------------
// Contraction for the C == 1 case (see CeedTensorContractApply_Avx), where
// the accumulation collapses to a matrix product:
//   v[a, j] += sum_b t[j * t_stride_0 + b * t_stride_1] * u[a, b]
// With no contiguous c index to vectorize, this kernel vectorizes over j
// instead, holding an AA x (JJ / 4) register tile of v while the b loop runs.
//
// Notes:
//   - JJ must be a multiple of 4 (each tile row holds JJ / 4 vectors).
//   - `contract`, `C`, and `add` are unused here: accumulation is
//     unconditional and the caller zeroes v first when add is false.
//   - In the column-remainder section, the final row block is deliberately
//     routed to the scalar cleanup loop (see A_break) so the unconditional
//     4-wide v loads/stores never run past the end of the v array.
static inline int CeedTensorContract_Avx_Single(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
                                                CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v,
                                                const CeedInt AA, const CeedInt JJ) {
  // Default layout: t indexed as t[j * B + b]; transposed: t[j + b * J]
  CeedInt t_stride_0 = B, t_stride_1 = 1;

  if (t_mode == CEED_TRANSPOSE) {
    t_stride_0 = 1;
    t_stride_1 = J;
  }

  // Blocks of 4 rows
  for (CeedInt a = 0; a < (A / AA) * AA; a += AA) {
    for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) {
      rtype vv[AA][JJ / 4];  // Output tile to be held in registers

      // Load the current v tile so the contraction accumulates into it
      for (CeedInt aa = 0; aa < AA; aa++) {
        for (CeedInt jj = 0; jj < JJ / 4; jj++) vv[aa][jj] = loadu(&v[(a + aa) * J + j + jj * 4]);
      }
      for (CeedInt b = 0; b < B; b++) {
        for (CeedInt jj = 0; jj < JJ / 4; jj++) {  // unroll
          // Gather 4 strided t entries into one vector (set takes lanes high-to-low)
          rtype tqv = set(t[(j + jj * 4 + 3) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 2) * t_stride_0 + b * t_stride_1],
                          t[(j + jj * 4 + 1) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 0) * t_stride_0 + b * t_stride_1]);

          for (CeedInt aa = 0; aa < AA; aa++) {  // unroll
            fmadd(vv[aa][jj], tqv, set1(u[(a + aa) * B + b]));
          }
        }
      }
      for (CeedInt aa = 0; aa < AA; aa++) {
        for (CeedInt jj = 0; jj < JJ / 4; jj++) storeu(&v[(a + aa) * J + j + jj * 4], vv[aa][jj]);
      }
    }
  }
  // Remainder of rows: same tiling, but only A - a (< AA) rows are active
  const CeedInt a = (A / AA) * AA;

  for (CeedInt j = 0; j < (J / JJ) * JJ; j += JJ) {
    rtype vv[AA][JJ / 4];  // Output tile to be held in registers

    for (CeedInt aa = 0; aa < A - a; aa++) {
      for (CeedInt jj = 0; jj < JJ / 4; jj++) vv[aa][jj] = loadu(&v[(a + aa) * J + j + jj * 4]);
    }
    for (CeedInt b = 0; b < B; b++) {
      for (CeedInt jj = 0; jj < JJ / 4; jj++) {  // unroll
        rtype tqv = set(t[(j + jj * 4 + 3) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 2) * t_stride_0 + b * t_stride_1],
                        t[(j + jj * 4 + 1) * t_stride_0 + b * t_stride_1], t[(j + jj * 4 + 0) * t_stride_0 + b * t_stride_1]);

        for (CeedInt aa = 0; aa < A - a; aa++) {  // unroll
          fmadd(vv[aa][jj], tqv, set1(u[(a + aa) * B + b]));
        }
      }
    }
    for (CeedInt aa = 0; aa < A - a; aa++) {
      for (CeedInt jj = 0; jj < JJ / 4; jj++) storeu(&v[(a + aa) * J + j + jj * 4], vv[aa][jj]);
    }
  }
  // Column remainder
  // Rows below A_break are vectorized; rows in [A_break, A) are done scalar.
  // When A divides evenly by AA the final full block is still sent to the
  // scalar path, so the vector path never touches the tail of v.
  const CeedInt A_break = A % AA ? (A / AA) * AA : (A / AA - 1) * AA;

  // Blocks of 4 columns
  for (CeedInt j = (J / JJ) * JJ; j < J; j += 4) {
    // Blocks of 4 rows
    for (CeedInt a = 0; a < A_break; a += AA) {
      rtype vv[AA];  // Output tile to be held in registers

      for (CeedInt aa = 0; aa < AA; aa++) vv[aa] = loadu(&v[(a + aa) * J + j]);
      for (CeedInt b = 0; b < B; b++) {
        rtype tqv;

        // Fewer than 4 columns of t left: zero-pad the upper lanes so they
        // contribute nothing to the accumulation
        if (J - j == 1) {
          tqv = set(0.0, 0.0, 0.0, t[(j + 0) * t_stride_0 + b * t_stride_1]);
        } else if (J - j == 2) {
          tqv = set(0.0, 0.0, t[(j + 1) * t_stride_0 + b * t_stride_1], t[(j + 0) * t_stride_0 + b * t_stride_1]);
        } else if (J - 3 == j) {  // i.e., J - j == 3
          tqv =
              set(0.0, t[(j + 2) * t_stride_0 + b * t_stride_1], t[(j + 1) * t_stride_0 + b * t_stride_1], t[(j + 0) * t_stride_0 + b * t_stride_1]);
        } else {
          tqv = set(t[(j + 3) * t_stride_0 + b * t_stride_1], t[(j + 2) * t_stride_0 + b * t_stride_1], t[(j + 1) * t_stride_0 + b * t_stride_1],
                    t[(j + 0) * t_stride_0 + b * t_stride_1]);
        }
        for (CeedInt aa = 0; aa < AA; aa++) {  // unroll
          fmadd(vv[aa], tqv, set1(u[(a + aa) * B + b]));
        }
      }
      for (CeedInt aa = 0; aa < AA; aa++) storeu(&v[(a + aa) * J + j], vv[aa]);
    }
  }
  // Remainder of rows, all columns: plain scalar accumulation
  for (CeedInt b = 0; b < B; b++) {
    for (CeedInt j = (J / JJ) * JJ; j < J; j++) {
      const CeedScalar tq = t[j * t_stride_0 + b * t_stride_1];

      for (CeedInt a = A_break; a < A; a++) v[a * J + j] += tq * u[a * B + b];
    }
  }
  return CEED_ERROR_SUCCESS;
}
250
251 //------------------------------------------------------------------------------
252 // Tensor Contract - Common Sizes
253 //------------------------------------------------------------------------------
// Blocked contraction specialized to a 4-row by 8-column register tile
static int CeedTensorContract_Avx_Blocked_4_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
                                              CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
  const CeedInt tile_rows = 4, tile_cols = 8;

  return CeedTensorContract_Avx_Blocked(contract, A, B, C, J, t, t_mode, add, u, v, tile_rows, tile_cols);
}
// Column-remainder contraction specialized to an 8-row by 8-column blocking
static int CeedTensorContract_Avx_Remainder_8_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
                                                CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
  const CeedInt tile_rows = 8, tile_cols = 8;

  return CeedTensorContract_Avx_Remainder(contract, A, B, C, J, t, t_mode, add, u, v, tile_rows, tile_cols);
}
// C == 1 contraction specialized to a 4-row by 8-column register tile
static int CeedTensorContract_Avx_Single_4_8(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
                                             CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
  const CeedInt tile_rows = 4, tile_cols = 8;

  return CeedTensorContract_Avx_Single(contract, A, B, C, J, t, t_mode, add, u, v, tile_rows, tile_cols);
}
266
267 //------------------------------------------------------------------------------
268 // Tensor Contract Apply
269 //------------------------------------------------------------------------------
// Dispatch a tensor contraction to the appropriate specialized kernel.
// When add is false, the output v is zeroed first; the kernels themselves
// always accumulate, so they are invoked with add = true.
static int CeedTensorContractApply_Avx(CeedTensorContract contract, CeedInt A, CeedInt B, CeedInt C, CeedInt J, const CeedScalar *restrict t,
                                       CeedTransposeMode t_mode, const CeedInt add, const CeedScalar *restrict u, CeedScalar *restrict v) {
  const CeedInt blk_size = 8;

  if (!add) {
    const CeedInt num_out = A * J * C;

    for (CeedInt i = 0; i < num_out; i++) v[i] = (CeedScalar)0.0;
  }

  if (C == 1) {
    // No contiguous column index to vectorize; use the C=1 kernel
    CeedTensorContract_Avx_Single_4_8(contract, A, B, C, J, t, t_mode, true, u, v);
    return CEED_ERROR_SUCCESS;
  }

  // Full blocks of 8 columns
  if (C >= blk_size) CeedTensorContract_Avx_Blocked_4_8(contract, A, B, C, J, t, t_mode, true, u, v);
  // Leftover columns
  if (C % blk_size) CeedTensorContract_Avx_Remainder_8_8(contract, A, B, C, J, t, t_mode, true, u, v);
  return CEED_ERROR_SUCCESS;
}
289
290 //------------------------------------------------------------------------------
291 // Tensor Contract Create
292 //------------------------------------------------------------------------------
// Backend entry point: register the AVX Apply implementation in this
// contract's backend function table. CeedCallBackend propagates any error
// from the registration call.
int CeedTensorContractCreate_Avx(CeedTensorContract contract) {
  CeedCallBackend(CeedSetBackendFunction(CeedTensorContractReturnCeed(contract), "TensorContract", contract, "Apply", CeedTensorContractApply_Avx));
  return CEED_ERROR_SUCCESS;
}
297
298 //------------------------------------------------------------------------------
299