// Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed
7f80f4a74SSebastian Grimberg
/// @file
/// Internal header for MAGMA backend common non-tensor basis definitions
10509d4af6SJeremy L Thompson #pragma once
11f80f4a74SSebastian Grimberg
123c1e2affSSebastian Grimberg #include "magma-common-defs.h"
13f80f4a74SSebastian Grimberg
////////////////////////////////////////////////////////////////////////////////
// read A (no-trans) from global to reg., staging the whole matrix through sA
// A is (P x Q), indexed with leading dimension P (entry (r, c) lives at c * P + r)
// 2D thread config. with (P x BY) threads, flattened as tid = ty * P + tx
//
//   tx, ty : thread coordinates inside the block
//   dA     : source array, P * Q entries are read
//   sA     : staging array, P * Q entries are written by the block and read back after the barrier
//            (presumably shared memory -- it must be visible to every thread of the block)
//   rA     : output, rA[j] = A(tx, j) for j = 0 .. Q-1
//
// one __syncthreads() separates the staging copy from the register read, so every thread of the block has to call
// this function
// no sync at the end of the function
template <typename T, int P, int Q, int BY>
static __device__ __inline__ void read_A_notrans_g2r_1D_nosync(const int tx, const int ty, const T *dA, T *sA, T rA[Q]) {
  // The copy loop advances by P * BY; a non-positive stride would never terminate
  static_assert(P > 0 && BY > 0, "read_A_notrans_g2r_1D_nosync: P and BY must be positive");

  const int tid = ty * P + tx;
  int       i;

  // Full passes: each one moves P * BY consecutive entries, one per thread
#pragma unroll
  for (i = 0; i < P * Q - P * BY; i += P * BY) {
    sA[i + tid] = dA[i + tid];
  }
  // Final pass, possibly partial: only threads that still map inside the matrix copy
  if (i + tid < P * Q) {
    sA[i + tid] = dA[i + tid];
  }
  __syncthreads();

  // Gather row tx: stride-P walk through the staged matrix
#pragma unroll
  for (int j = 0; j < Q; j++) {
    rA[j] = sA[j * P + tx];
  }
}
38f80f4a74SSebastian Grimberg
////////////////////////////////////////////////////////////////////////////////
// read A (trans) from global to reg., staging the whole matrix through sA
// A is (P x Q)
// 2D thread config. with (P x BY) threads, flattened as tid = ty * P + tx
//
//   tx, ty : thread coordinates inside the block
//   dA     : source array, P * Q entries are read
//   sA     : staging array, P * Q entries are written by the block and read back after the barrier
//            (presumably shared memory -- it must be visible to every thread of the block)
//   rA     : output, rA[j] = sA[tx * Q + j] for j = 0 .. Q-1, i.e. thread tx takes the Q consecutive entries
//            starting at offset tx * Q (the stored array is walked with leading dimension Q)
//
// one __syncthreads() separates the staging copy from the register read, so every thread of the block has to call
// this function
// no sync at the end of the function
template <typename T, int P, int Q, int BY>
static __device__ __inline__ void read_A_trans_g2r_1D_nosync(const int tx, const int ty, const T *dA, T *sA, T rA[Q]) {
  // The copy loop advances by P * BY; a non-positive stride would never terminate
  static_assert(P > 0 && BY > 0, "read_A_trans_g2r_1D_nosync: P and BY must be positive");

  const int tid = ty * P + tx;
  int       i;

  // Full passes: each one moves P * BY consecutive entries, one per thread
#pragma unroll
  for (i = 0; i < P * Q - P * BY; i += P * BY) {
    sA[i + tid] = dA[i + tid];
  }
  // Final pass, possibly partial: only threads that still map inside the matrix copy
  if (i + tid < P * Q) {
    sA[i + tid] = dA[i + tid];
  }
  __syncthreads();

  // Gather the contiguous run of Q entries owned by thread tx
#pragma unroll
  for (int j = 0; j < Q; j++) {
    rA[j] = sA[tx * Q + j];
  }
}
63f80f4a74SSebastian Grimberg
////////////////////////////////////////////////////////////////////////////////
// read B from global to shared
// B is (Q x NB); only the first n columns (Q * n leading entries) are copied
// 1D thread config. with (P x 1) threads
//
//   tx : thread index, used as the offset inside each chunk of P entries
//   n  : number of columns to copy (presumably 0 <= n <= NB -- confirm against callers)
//   dB : source array, Q * n entries are read
//   sB : destination array, Q * n entries are written; entries beyond Q * n are left untouched
//
// n == NB takes a path whose trip count is a compile-time constant so it can be unrolled
// no barrier is issued here: the caller has to synchronize before other threads read sB
// no sync at the end of the function
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void read_B_g2s_1D_nosync(const int tx, const int n, const T *dB, T *sB) {
  // The copy loops advance by P; a non-positive stride would never terminate
  static_assert(P > 0, "read_B_g2s_1D_nosync: P must be positive");

  int i;

  if (n != NB) {
    // Partial block: trip count only known at run time
    for (i = 0; i < Q * n - P; i += P) {
      sB[i + tx] = dB[i + tx];
    }
  } else {
    // Full block: constant trip count
#pragma unroll
    for (i = 0; i < Q * NB - P; i += P) {
      sB[i + tx] = dB[i + tx];
    }
  }
  // Final chunk, possibly partial: only threads that still map inside the Q * n entries copy
  if (i + tx < Q * n) {
    sB[i + tx] = dB[i + tx];
  }
}
87f80f4a74SSebastian Grimberg
////////////////////////////////////////////////////////////////////////////////
// write C from reg. to global
// C is (P x NB), indexed with leading dimension P (entry (r, c) lives at c * P + r)
// 1D thread config. with (P x 1) threads, thread tx owns row tx
//
//   tx : row written by this thread
//   n  : number of columns to write (presumably 0 <= n <= NB, since rC holds NB entries -- confirm against callers)
//   rC : this thread's row of C
//   dC : destination array; entries tx, P + tx, ..., (n - 1) * P + tx are overwritten
//
// n == NB takes a path whose trip count is a compile-time constant so it can be unrolled
// no sync at the end of the function
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void write_C_r2g_1D_nosync(const int tx, const int n, T rC[NB], T *dC) {
  if (n != NB) {
    // Partial block: trip count only known at run time
    for (int i = 0; i < n; i++) {
      dC[i * P + tx] = rC[i];
    }
  } else {
    // Full block: constant trip count
#pragma unroll
    for (int i = 0; i < NB; i++) {
      dC[i * P + tx] = rC[i];
    }
  }
}
1063c1e2affSSebastian Grimberg
////////////////////////////////////////////////////////////////////////////////
// sum into C from reg. to global
// C is (P x NB), indexed with leading dimension P (entry (r, c) lives at c * P + r)
// 1D thread config. with (P x 1) threads, thread tx owns row tx
//
//   tx : row updated by this thread
//   n  : number of columns to update (presumably 0 <= n <= NB, since rC holds NB entries -- confirm against callers)
//   rC : this thread's row of contributions
//   dC : destination array; entries tx, P + tx, ..., (n - 1) * P + tx are incremented with a plain (non-atomic)
//        read-modify-write
//
// n == NB takes a path whose trip count is a compile-time constant so it can be unrolled
// no sync at the end of the function
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void sum_C_r2g_1D_nosync(const int tx, const int n, T rC[NB], T *dC) {
  if (n != NB) {
    // Partial block: trip count only known at run time
    for (int i = 0; i < n; i++) {
      dC[i * P + tx] += rC[i];
    }
  } else {
    // Full block: constant trip count
#pragma unroll
    for (int i = 0; i < NB; i++) {
      dC[i * P + tx] += rC[i];
    }
  }
}
125db2becc9SJeremy L Thompson
////////////////////////////////////////////////////////////////////////////////
// multiply C = A x B using 1D threads in P x 1 config
// A (P x Q) in reg., one row per thread
// B (Q x NB) in shared memory, indexed with leading dimension Q (entry (r, c) lives at c * Q + r)
// C in registers -- one row per thread
//
//   rA : this thread's row of A, Q entries
//   sB : B, Q * NB entries are read
//   rC : output, rC[i] = sum_j rA[j] * B(j, i); any previous content is overwritten
//
// no sync at the end of the function
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void mul_rAsBrC_1D_nosync(T rA[Q], T *sB, T rC[NB]) {
  T rB[Q];

#pragma unroll
  for (int i = 0; i < NB; i++) {
    // Pull column i of B into a local array before the multiply-accumulate pass
#pragma unroll
    for (int j = 0; j < Q; j++) {
      rB[j] = sB[i * Q + j];
    }
    // Start from zero of type T (avoids a double literal when T is float)
    rC[i] = static_cast<T>(0);
#pragma unroll
    for (int j = 0; j < Q; j++) {
      rC[i] += rA[j] * rB[j];
    }
  }
}
149f80f4a74SSebastian Grimberg
////////////////////////////////////////////////////////////////////////////////
// multiply C += A x B using 1D threads in P x 1 config
// A (P x Q) in reg., one row per thread
// B (Q x NB) in shared memory, indexed with leading dimension Q (entry (r, c) lives at c * Q + r)
// C in registers -- one row per thread
//
//   rA : this thread's row of A, Q entries
//   sB : B, Q * NB entries are read
//   rC : in/out, rC[i] += sum_j rA[j] * B(j, i); the caller provides the initial values
//
// no sync at the end of the function
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void addmul_rAsBrC_1D_nosync(T rA[Q], T *sB, T rC[NB]) {
  T rB[Q];

#pragma unroll
  for (int i = 0; i < NB; i++) {
    // Pull column i of B into a local array before the multiply-accumulate pass
#pragma unroll
    for (int j = 0; j < Q; j++) {
      rB[j] = sB[i * Q + j];
    }
    // Accumulate on top of whatever rC[i] already holds
#pragma unroll
    for (int j = 0; j < Q; j++) {
      rC[i] += rA[j] * rB[j];
    }
  }
}
172