15aed82e4SJeremy L Thompson // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors. 2f80f4a74SSebastian Grimberg // All Rights Reserved. See the top-level LICENSE and NOTICE files for details. 3f80f4a74SSebastian Grimberg // 4f80f4a74SSebastian Grimberg // SPDX-License-Identifier: BSD-2-Clause 5f80f4a74SSebastian Grimberg // 6f80f4a74SSebastian Grimberg // This file is part of CEED: http://github.com/ceed 7f80f4a74SSebastian Grimberg 83c1e2affSSebastian Grimberg /// @file 93c1e2affSSebastian Grimberg /// Internal header for MAGMA backend common non-tensor basis definitions 10509d4af6SJeremy L Thompson #pragma once 11f80f4a74SSebastian Grimberg 123c1e2affSSebastian Grimberg #include "magma-common-defs.h" 13f80f4a74SSebastian Grimberg 14f80f4a74SSebastian Grimberg //////////////////////////////////////////////////////////////////////////////// 15f80f4a74SSebastian Grimberg // read A (no-trans) from global to reg. 163c1e2affSSebastian Grimberg // A is (P x Q) 17833aa127SSebastian Grimberg // 2D thread config. 
////////////////////////////////////////////////////////////////////////////////
// read A (no-trans) from global to reg.
// A is (P x Q); thread tx receives row tx of A, where entry (tx, q) is read from
// dA[q * P + tx] (P-fastest layout).
// 2D thread config. with (P x BY) threads; the flattened id ty * P + tx is assumed < P * BY.
//
// Steps:
//   1. all P * BY threads cooperatively copy the P * Q entries dA[0 .. P*Q) into sA
//   2. __syncthreads()
//   3. each thread loads rA[q] = sA[q * P + tx] for q = 0 .. Q-1
//
// sA must hold at least P * Q entries (presumably shared memory, since the copy is
// published to the block with __syncthreads() -- confirm at the call site).
// The __syncthreads() in step 2 must be reached by every thread of the block, so this
// must not be called under thread-divergent control flow.
// no sync at the end of the function: the reads in step 3 are not followed by a barrier,
// so sA must not be overwritten until the caller synchronizes.
template <typename T, int P, int Q, int BY>
static __device__ __inline__ void read_A_notrans_g2r_1D_nosync(const int tx, const int ty, const T *dA, T *sA, T rA[Q]) {
  const int lin    = ty * P + tx;  // linear thread id within the (P x BY) block
  const int total  = P * Q;        // number of entries of A
  const int stride = P * BY;       // entries copied per cooperative pass
  // passes in which every thread's index is guaranteed to be in range;
  // the final (possibly partial) pass is handled by the guarded copy below
  const int full_passes = (total > stride) ? (total - 1) / stride : 0;

#pragma unroll
  for (int pass = 0; pass < full_passes; pass++) {
    const int k = pass * stride + lin;
    sA[k]       = dA[k];
  }
  const int tail = full_passes * stride + lin;
  if (tail < total) {
    sA[tail] = dA[tail];
  }
  __syncthreads();

#pragma unroll
  for (int q = 0; q < Q; q++) {
    rA[q] = sA[q * P + tx];
  }
}
////////////////////////////////////////////////////////////////////////////////
// read A (trans) from global to reg.
// A is (P x Q); thread tx receives row tx of A, where entry (tx, q) is read from
// dA[tx * Q + q] (Q-fastest layout), i.e. the Q consecutive entries starting at tx * Q.
// 2D thread config. with (P x BY) threads; the flattened id ty * P + tx is assumed < P * BY.
//
// Steps:
//   1. all P * BY threads cooperatively copy the P * Q entries dA[0 .. P*Q) into sA
//   2. __syncthreads()
//   3. each thread loads rA[q] = sA[tx * Q + q] for q = 0 .. Q-1
//
// sA must hold at least P * Q entries (presumably shared memory, since the copy is
// published to the block with __syncthreads() -- confirm at the call site).
// The __syncthreads() in step 2 must be reached by every thread of the block, so this
// must not be called under thread-divergent control flow.
// no sync at the end of the function: the reads in step 3 are not followed by a barrier,
// so sA must not be overwritten until the caller synchronizes.
template <typename T, int P, int Q, int BY>
static __device__ __inline__ void read_A_trans_g2r_1D_nosync(const int tx, const int ty, const T *dA, T *sA, T rA[Q]) {
  const int flat_id     = ty * P + tx;  // linear thread id within the (P x BY) block
  const int num_entries = P * Q;        // number of entries of A
  const int num_threads = P * BY;       // entries copied per cooperative pass
  // passes in which every thread's index is guaranteed to be in range;
  // the final (possibly partial) pass is handled by the guarded copy below
  const int whole_passes = (num_entries > num_threads) ? (num_entries - 1) / num_threads : 0;

#pragma unroll
  for (int r = 0; r < whole_passes; r++) {
    const int idx = r * num_threads + flat_id;
    sA[idx]       = dA[idx];
  }
  const int leftover = whole_passes * num_threads + flat_id;
  if (leftover < num_entries) {
    sA[leftover] = dA[leftover];
  }
  __syncthreads();

  // this thread's row of A is stored contiguously in sA
  const T *my_row = sA + tx * Q;
#pragma unroll
  for (int q = 0; q < Q; q++) {
    rA[q] = my_row[q];
  }
}
////////////////////////////////////////////////////////////////////////////////
// read B from global to shared
// B is (Q x NB); only the first n columns, i.e. the Q * n contiguous entries
// dB[0 .. Q*n), are copied into sB[0 .. Q*n).
// 1D thread config. with (P x 1) threads: thread tx copies entries tx, tx + P, tx + 2P, ...
// When n == NB the copy loop has a compile-time trip count and is unrolled.
// no sync at the end of the function: a barrier is required before sB is read by
// threads other than the one that wrote each entry.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void read_B_g2s_1D_nosync(const int tx, const int n, const T *dB, T *sB) {
  int base = 0;

  if (n == NB) {
#pragma unroll
    for (; base < Q * NB - P; base += P) {
      sB[base + tx] = dB[base + tx];
    }
  } else {
    for (; base < Q * n - P; base += P) {
      sB[base + tx] = dB[base + tx];
    }
  }

  // last (possibly partial) stride: only threads whose index is still in range copy
  const int last = base + tx;
  if (last < Q * n) {
    sB[last] = dB[last];
  }
}
////////////////////////////////////////////////////////////////////////////////
// write C from reg. to global
// C is (P x NB), stored P-fastest: thread tx writes its row rC[0 .. n) to
// dC[i * P + tx] for columns i = 0 .. n-1; columns n .. NB-1 are left untouched.
// 1D thread config. with (P x 1) threads
// rC is only read, so it is taken as const (callers passing a non-const array still work).
// Reads rC[0 .. n), so n is assumed to satisfy 0 <= n <= NB -- presumably guaranteed by the
// caller; confirm. When n == NB the loop has a compile-time trip count and is unrolled.
// no sync at the end of the function
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void write_C_r2g_1D_nosync(const int tx, const int n, const T rC[NB], T *dC) {
  if (n != NB) {
    for (int i = 0; i < n; i++) {
      dC[i * P + tx] = rC[i];
    }
  } else {
#pragma unroll
    for (int i = 0; i < NB; i++) {
      dC[i * P + tx] = rC[i];
    }
  }
}
////////////////////////////////////////////////////////////////////////////////
// sum into C from reg. to global
// C is (P x NB), stored P-fastest: dC[i * P + tx] += rC[i] for columns i = 0 .. n-1.
// 1D thread config. with (P x 1) threads
// This is a plain (non-atomic) read-modify-write of global memory. It is only correct if no
// other thread updates the same dC entries concurrently -- presumably guaranteed by the
// caller's partitioning of C; confirm.
// Reads rC[0 .. n), so n is assumed to satisfy 0 <= n <= NB -- confirm at the call site.
// rC is only read, so it is taken as const (callers passing a non-const array still work).
// no sync at the end of the function
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void sum_C_r2g_1D_nosync(const int tx, const int n, const T rC[NB], T *dC) {
  if (n != NB) {
    for (int i = 0; i < n; i++) {
      dC[i * P + tx] += rC[i];
    }
  } else {
#pragma unroll
    for (int i = 0; i < NB; i++) {
      dC[i * P + tx] += rC[i];
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// multiply C += A x B using 1D threads in P x 1 config
// A (P x Q) in reg., one row per thread
// B (Q x NB) in shared memory; column i is stored contiguously at sB[i * Q .. i * Q + Q)
// C in registers -- one row per thread: rC[i] += sum_j rA[j] * sB[i * Q + j]
// rA and sB are only read, so they are taken as const (non-const callers still work).
// no sync at the end of the function; sB must already be visible to the calling thread.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void addmul_rAsBrC_1D_nosync(const T rA[Q], const T *sB, T rC[NB]) {
  T rB[Q];

#pragma unroll
  for (int i = 0; i < NB; i++) {
    // stage column i of B in registers before the multiply-adds
#pragma unroll
    for (int j = 0; j < Q; j++) {
      rB[j] = sB[i * Q + j];
    }
#pragma unroll
    for (int j = 0; j < Q; j++) {
      rC[i] += rA[j] * rB[j];
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
// multiply C = A x B using 1D threads in P x 1 config
// A (P x Q) in reg., one row per thread
// B (Q x NB) in shared memory; column i is stored contiguously at sB[i * Q .. i * Q + Q)
// C in registers -- one row per thread: rC[i] = sum_j rA[j] * sB[i * Q + j]
// Implemented as a zero-fill of rC followed by addmul_rAsBrC_1D_nosync. This performs the same
// sequence of multiply-adds (starting from zero) as a from-scratch product, so results match.
// rA and sB are only read, so they are taken as const (non-const callers still work).
// no sync at the end of the function; sB must already be visible to the calling thread.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void mul_rAsBrC_1D_nosync(const T rA[Q], const T *sB, T rC[NB]) {
#pragma unroll
  for (int i = 0; i < NB; i++) {
    rC[i] = static_cast<T>(0.0);
  }
  addmul_rAsBrC_1D_nosync<T, P, Q, NB>(rA, sB, rC);
}