xref: /libCEED/include/ceed/jit-source/magma/magma-common-nontensor.h (revision 5aed82e4fa97acf4ba24a7f10a35f5303a6798e0)
1*5aed82e4SJeremy L Thompson // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2f80f4a74SSebastian Grimberg // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3f80f4a74SSebastian Grimberg //
4f80f4a74SSebastian Grimberg // SPDX-License-Identifier: BSD-2-Clause
5f80f4a74SSebastian Grimberg //
6f80f4a74SSebastian Grimberg // This file is part of CEED:  http://github.com/ceed
7f80f4a74SSebastian Grimberg 
83c1e2affSSebastian Grimberg /// @file
93c1e2affSSebastian Grimberg /// Internal header for MAGMA backend common non-tensor basis definitions
10f80f4a74SSebastian Grimberg #ifndef CEED_MAGMA_COMMON_NONTENSOR_H
11f80f4a74SSebastian Grimberg #define CEED_MAGMA_COMMON_NONTENSOR_H
12f80f4a74SSebastian Grimberg 
133c1e2affSSebastian Grimberg #include "magma-common-defs.h"
14f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// load A (no-trans) from global to reg., staged through the buffer sA
// A is (P x Q)
// 2D thread config. with (P x BY) threads
//   tx, ty : thread indices, flattened below as ty * P + tx
//   dA     : source of the P * Q entries of A
//   sA     : staging buffer for P * Q entries, filled cooperatively by all threads
//   rA     : output, rA[j] receives sA[j * P + tx] for j = 0..Q-1
// a __syncthreads() separates the staging copy from the register loads;
// no sync at the end of the function
template <typename T, int P, int Q, int BY>
static __device__ __inline__ void read_A_notrans_g2r_1D_nosync(const int tx, const int ty, const T *dA, T *sA, T rA[Q]) {
  const int num_threads = P * BY;
  const int num_entries = P * Q;
  const int flat_id     = ty * P + tx;
  int       offset;

  // Full passes: every thread copies one entry per pass of num_threads entries
#pragma unroll
  for (offset = 0; offset < num_entries - num_threads; offset += num_threads) {
    sA[offset + flat_id] = dA[offset + flat_id];
  }
  // Final (possibly partial) pass: only threads that still map to a valid entry copy
  const int last = offset + flat_id;
  if (last < num_entries) {
    sA[last] = dA[last];
  }
  // All of sA must be written before any thread gathers its own entries
  __syncthreads();

  // Gather the entries belonging to index tx, stepping through sA with stride P
  const T *row = sA + tx;
#pragma unroll
  for (int col = 0; col < Q; ++col) {
    rA[col] = row[col * P];
  }
}
39f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// load A (trans) from global to reg., staged through the buffer sA
// A is (P x Q)
// 2D thread config. with (P x BY) threads
//   tx, ty : thread indices, flattened below as ty * P + tx
//   dA     : source of the P * Q entries of A
//   sA     : staging buffer for P * Q entries, filled cooperatively by all threads
//   rA     : output, rA[j] receives sA[tx * Q + j] for j = 0..Q-1
// a __syncthreads() separates the staging copy from the register loads;
// no sync at the end of the function
template <typename T, int P, int Q, int BY>
static __device__ __inline__ void read_A_trans_g2r_1D_nosync(const int tx, const int ty, const T *dA, T *sA, T rA[Q]) {
  const int num_threads = P * BY;
  const int num_entries = P * Q;
  const int flat_id     = ty * P + tx;
  int       offset;

  // Full passes: every thread copies one entry per pass of num_threads entries
#pragma unroll
  for (offset = 0; offset < num_entries - num_threads; offset += num_threads) {
    sA[offset + flat_id] = dA[offset + flat_id];
  }
  // Final (possibly partial) pass: only threads that still map to a valid entry copy
  const int last = offset + flat_id;
  if (last < num_entries) {
    sA[last] = dA[last];
  }
  // All of sA must be written before any thread gathers its own entries
  __syncthreads();

  // Gather the Q consecutive entries that start at offset tx * Q in sA
  const T *segment = sA + tx * Q;
#pragma unroll
  for (int k = 0; k < Q; ++k) {
    rA[k] = segment[k];
  }
}
64f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// copy B from global to shared
// B is (Q x NB); only the first Q * n entries are copied
// 1D thread config. with (P x 1) threads
//   tx : thread index, 0..P-1
//   n  : number of columns to copy; n == NB selects the unrolled path
//   dB : source of the entries of B
//   sB : destination buffer
// no sync at the end of the function
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void read_B_g2s_1D_nosync(const int tx, const int n, const T *dB, T *sB) {
  const int num_entries = Q * n;
  int       offset;

  if (n == NB) {
    // Full block of columns: loop bounds are compile-time constants, so unroll
#pragma unroll
    for (offset = 0; offset < Q * NB - P; offset += P) {
      sB[offset + tx] = dB[offset + tx];
    }
  } else {
    // Partial block of columns: trip count is only known at run time
    for (offset = 0; offset < num_entries - P; offset += P) {
      sB[offset + tx] = dB[offset + tx];
    }
  }
  // Final (possibly partial) pass of P entries
  const int last = offset + tx;
  if (last < num_entries) {
    sB[last] = dB[last];
  }
}
88f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// store C from reg. to global
// C is (P x NB); only the first n columns are written
// 1D thread config. with (P x 1) threads
//   tx : thread index, 0..P-1
//   n  : number of columns to write; n == NB selects the unrolled path
//   rC : per-thread values, rC[i] is stored to dC[i * P + tx]
//   dC : destination buffer
// no sync at the end of the function
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void write_C_r2g_1D_nosync(const int tx, const int n, T rC[NB], T *dC) {
  // This thread's entries sit P apart in dC, starting at index tx
  T *row = dC + tx;

  if (n == NB) {
    // Full block of columns: trip count is a compile-time constant, so unroll
#pragma unroll
    for (int col = 0; col < NB; ++col) {
      row[col * P] = rC[col];
    }
  } else {
    // Partial block of columns: trip count is only known at run time
    for (int col = 0; col < n; ++col) {
      row[col * P] = rC[col];
    }
  }
}
1073c1e2affSSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// compute C = A x B using 1D threads in P x 1 config
// A (P x Q)  in reg., one row per thread (rA holds the Q entries of that row)
// B (Q x NB) in shared memory, column i starts at sB[i * Q]
// C in registers -- one row per thread (rC[i] is overwritten for i = 0..NB-1)
// no sync at the end of the function
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void mul_rAsBrC_1D_nosync(T rA[Q], T *sB, T rC[NB]) {
  T b_col[Q];

#pragma unroll
  for (int col = 0; col < NB; ++col) {
    // Stage column `col` of B into thread-local storage before using it
    const T *sB_col = sB + col * Q;
#pragma unroll
    for (int k = 0; k < Q; ++k) {
      b_col[k] = sB_col[k];
    }

    // Dot product of this thread's row of A with the staged column, starting from zero
    T &c_entry = rC[col];
    c_entry    = 0.0;
#pragma unroll
    for (int k = 0; k < Q; ++k) {
      c_entry += rA[k] * b_col[k];
    }
  }
}
131f80f4a74SSebastian Grimberg 
////////////////////////////////////////////////////////////////////////////////
// compute C += A x B using 1D threads in P x 1 config
// A (P x Q)  in reg., one row per thread (rA holds the Q entries of that row)
// B (Q x NB) in shared memory, column i starts at sB[i * Q]
// C in registers -- one row per thread (rC[i] is accumulated into, not reset)
// no sync at the end of the function
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void addmul_rAsBrC_1D_nosync(T rA[Q], T *sB, T rC[NB]) {
  T b_col[Q];

#pragma unroll
  for (int col = 0; col < NB; ++col) {
    // Stage column `col` of B into thread-local storage before using it
    const T *sB_col = sB + col * Q;
#pragma unroll
    for (int k = 0; k < Q; ++k) {
      b_col[k] = sB_col[k];
    }

    // Add the dot product of this thread's row of A with the staged column onto the existing value
    T &c_entry = rC[col];
#pragma unroll
    for (int k = 0; k < Q; ++k) {
      c_entry += rA[k] * b_col[k];
    }
  }
}
154f80f4a74SSebastian Grimberg 
155f80f4a74SSebastian Grimberg #endif  // CEED_MAGMA_COMMON_NONTENSOR_H
156