// xref: /libCEED/include/ceed/jit-source/magma/magma-common-nontensor.h (revision d4cc18453651bd0f94c1a2e078b2646a92dafdcc)
// Copyright (c) 2017-2026, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED:  http://github.com/ceed

/// @file
/// Internal header for MAGMA backend common non-tensor basis definitions
#pragma once

#include "magma-common-defs.h"

////////////////////////////////////////////////////////////////////////////////
// read A (no-trans) from global to reg.
// A is (P x Q)
// 2D thread config. with (P x BY) threads
// no sync at the end of the function
//
// Two phases:
//   1. The P * Q entries of dA are copied verbatim into sA. The thread with flat index tid = ty * P + tx copies
//      entries tid, tid + P * BY, tid + 2 * P * BY, ...
//   2. After a __syncthreads(), the calling thread gathers rA[j] = sA[j * P + tx] for j = 0, ..., Q - 1.
//
//   tx, ty  thread indices supplied by the caller (the thread config above implies 0 <= tx < P and 0 <= ty < BY)
//   dA      source array; P * Q entries are read
//   sA      staging array; P * Q entries are written (presumably shared memory -- confirm against the callers)
//   rA      output; Q entries are written
//
// Because of the __syncthreads() between the two phases, this function must be reached by every thread of the
// block. No barrier follows the gather into rA.
template <typename T, int P, int Q, int BY>
static __device__ __inline__ void read_A_notrans_g2r_1D_nosync(const int tx, const int ty, const T *dA, T *sA, T rA[Q]) {
  const int num_threads = P * BY;
  const int num_entries = P * Q;
  const int tid         = ty * P + tx;
  int       i;

  // Full strides: every thread copies one entry per pass, no bounds check needed
#pragma unroll
  for (i = 0; i < num_entries - num_threads; i += num_threads) {
    sA[i + tid] = dA[i + tid];
  }
  // Last (possibly partial) stride: only threads that still map to a valid entry copy
  if (i + tid < num_entries) {
    sA[i + tid] = dA[i + tid];
  }
  // The gather below reads entries staged by other threads
  __syncthreads();

#pragma unroll
  for (int j = 0; j < Q; j++) {
    rA[j] = sA[j * P + tx];
  }
}

////////////////////////////////////////////////////////////////////////////////
// read A (trans) from global to reg.
// A is (P x Q)
// 2D thread config. with (P x BY) threads
// no sync at the end of the function
//
// Two phases:
//   1. The P * Q entries of dA are copied verbatim into sA. The thread with flat index tid = ty * P + tx copies
//      entries tid, tid + P * BY, tid + 2 * P * BY, ...
//   2. After a __syncthreads(), the calling thread gathers the contiguous run rA[j] = sA[tx * Q + j] for
//      j = 0, ..., Q - 1.
//
//   tx, ty  thread indices supplied by the caller (the thread config above implies 0 <= tx < P and 0 <= ty < BY)
//   dA      source array; P * Q entries are read
//   sA      staging array; P * Q entries are written (presumably shared memory -- confirm against the callers)
//   rA      output; Q entries are written
//
// Because of the __syncthreads() between the two phases, this function must be reached by every thread of the
// block. No barrier follows the gather into rA.
template <typename T, int P, int Q, int BY>
static __device__ __inline__ void read_A_trans_g2r_1D_nosync(const int tx, const int ty, const T *dA, T *sA, T rA[Q]) {
  const int num_threads = P * BY;
  const int num_entries = P * Q;
  const int tid         = ty * P + tx;
  int       i;

  // Full strides: every thread copies one entry per pass, no bounds check needed
#pragma unroll
  for (i = 0; i < num_entries - num_threads; i += num_threads) {
    sA[i + tid] = dA[i + tid];
  }
  // Last (possibly partial) stride: only threads that still map to a valid entry copy
  if (i + tid < num_entries) {
    sA[i + tid] = dA[i + tid];
  }
  // The gather below reads entries staged by other threads
  __syncthreads();

#pragma unroll
  for (int j = 0; j < Q; j++) {
    rA[j] = sA[tx * Q + j];
  }
}

////////////////////////////////////////////////////////////////////////////////
// read B from global to shared
// B is (Q x NB)
// 1D thread config. with (P x 1) threads
// no sync at the end of the function
//
// Copies the first Q * n entries of dB into sB verbatim. Thread tx copies entries tx, tx + P, tx + 2 * P, ...
//
//   tx  thread index supplied by the caller (the thread config above implies 0 <= tx < P)
//   n   number of columns of B to copy; n == NB selects the unrolled compile-time path
//   dB  source array; Q * n entries are read
//   sB  destination array; Q * n entries are written
//
// No barrier is issued here: the caller has to synchronize before other threads read sB.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void read_B_g2s_1D_nosync(const int tx, const int n, const T *dB, T *sB) {
  int i;

  if (n != NB) {
    // Partial batch: the trip count is only known at run time, so the loop is left rolled
    for (i = 0; i < Q * n - P; i += P) {
      sB[i + tx] = dB[i + tx];
    }
  } else {
    // Full batch: the trip count is a compile-time constant
#pragma unroll
    for (i = 0; i < Q * NB - P; i += P) {
      sB[i + tx] = dB[i + tx];
    }
  }
  // Last (possibly partial) stride: only threads that still map to a valid entry copy
  if (i + tx < Q * n) {
    sB[i + tx] = dB[i + tx];
  }
}

////////////////////////////////////////////////////////////////////////////////
// write C from reg. to global
// C is (P x NB)
// 1D thread config. with (P x 1) threads
// no sync at the end of the function
//
// The calling thread stores its entries as dC[i * P + tx] = rC[i] for i = 0, ..., n - 1, overwriting the previous
// contents of dC.
//
//   tx  thread index supplied by the caller (the thread config above implies 0 <= tx < P)
//   n   number of columns of C to write; n == NB selects the unrolled compile-time path
//   rC  per-thread values; the first n entries are read
//   dC  destination array, indexed with stride P between consecutive columns
//
// The template parameter Q is not used in the body.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void write_C_r2g_1D_nosync(const int tx, const int n, T rC[NB], T *dC) {
  if (n != NB) {
    // Partial batch: run-time trip count
    for (int i = 0; i < n; i++) {
      dC[i * P + tx] = rC[i];
    }
  } else {
    // Full batch: compile-time trip count
#pragma unroll
    for (int i = 0; i < NB; i++) {
      dC[i * P + tx] = rC[i];
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// sum into C from reg. to global
// C is (P x NB)
// 1D thread config. with (P x 1) threads
// no sync at the end of the function
//
// The calling thread accumulates its entries as dC[i * P + tx] += rC[i] for i = 0, ..., n - 1. The update is a plain
// read-modify-write, not an atomic operation.
//
//   tx  thread index supplied by the caller (the thread config above implies 0 <= tx < P)
//   n   number of columns of C to update; n == NB selects the unrolled compile-time path
//   rC  per-thread values; the first n entries are read
//   dC  array that is summed into, indexed with stride P between consecutive columns
//
// The template parameter Q is not used in the body.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void sum_C_r2g_1D_nosync(const int tx, const int n, T rC[NB], T *dC) {
  if (n != NB) {
    // Partial batch: run-time trip count
    for (int i = 0; i < n; i++) {
      dC[i * P + tx] += rC[i];
    }
  } else {
    // Full batch: compile-time trip count
#pragma unroll
    for (int i = 0; i < NB; i++) {
      dC[i * P + tx] += rC[i];
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// multiply C = A x B using 1D threads in P x 1 config
// A (P x Q)  in reg., one row per thread
// B (Q x NB) in shared memory
// C in registers -- one row per thread
// no sync at the end of the function
//
// For each column i = 0, ..., NB - 1 the calling thread computes rC[i] = sum_j rA[j] * sB[i * Q + j] with
// j = 0, ..., Q - 1. Any previous contents of rC are overwritten.
//
//   rA  the calling thread's Q entries of A
//   sB  B, read as sB[i * Q + j] for column i and row j; Q * NB entries are read
//   rC  output; NB entries are written
//
// The template parameter P is not used in the body. sB is only read here; the caller is responsible for making sure
// it is fully populated (and synchronized) beforehand.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void mul_rAsBrC_1D_nosync(T rA[Q], T *sB, T rC[NB]) {
  T rB[Q];

#pragma unroll
  for (int i = 0; i < NB; i++) {
    // Stage column i of B into a local array before the multiply-add loop
#pragma unroll
    for (int j = 0; j < Q; j++) {
      rB[j] = sB[i * Q + j];
    }
    // C = A x B: start every column from zero
    rC[i] = 0.0;
#pragma unroll
    for (int j = 0; j < Q; j++) {
      rC[i] += rA[j] * rB[j];
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// multiply C += A x B using 1D threads in P x 1 config
// A (P x Q)  in reg., one row per thread
// B (Q x NB) in shared memory
// C in registers -- one row per thread
// no sync at the end of the function
//
// For each column i = 0, ..., NB - 1 the calling thread accumulates rC[i] += sum_j rA[j] * sB[i * Q + j] with
// j = 0, ..., Q - 1. rC is not cleared first, so it must hold valid values on entry.
//
//   rA  the calling thread's Q entries of A
//   sB  B, read as sB[i * Q + j] for column i and row j; Q * NB entries are read
//   rC  accumulator; NB entries are read and updated
//
// The template parameter P is not used in the body. sB is only read here; the caller is responsible for making sure
// it is fully populated (and synchronized) beforehand.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void addmul_rAsBrC_1D_nosync(T rA[Q], T *sB, T rC[NB]) {
  T rB[Q];

#pragma unroll
  for (int i = 0; i < NB; i++) {
    // Stage column i of B into a local array before the multiply-add loop
#pragma unroll
    for (int j = 0; j < Q; j++) {
      rB[j] = sB[i * Q + j];
    }
    // C += A x B: accumulate on top of the incoming value of rC[i]
#pragma unroll
    for (int j = 0; j < Q; j++) {
      rC[i] += rA[j] * rB[j];
    }
  }
}