// Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed

/// @file
/// Internal header for MAGMA backend common non-tensor basis definitions
#ifndef CEED_MAGMA_COMMON_NONTENSOR_H
#define CEED_MAGMA_COMMON_NONTENSOR_H

#include "magma-common-defs.h"

////////////////////////////////////////////////////////////////////////////////
// read A (no-trans) from global to reg.
// A is (P x Q)
// 2D thread config.
// with (P x BY) threads
// no sync at the end of the function
//
// Phase 1: thread (tx, ty), whose flat index is ty * P + tx, copies entries of dA into the same positions of sA in passes of
//          P * BY entries; the last pass is bounds-checked so nothing at or beyond index P * Q is touched.
// Phase 2: after the __syncthreads() barrier, the thread gathers sA[j * P + tx] for j = 0, ..., Q - 1 into rA[j].
// The only barrier sits between the two phases; nothing synchronizes after rA has been filled.
// NOTE(review): the d/s/r prefixes presumably denote global/shared/register storage -- confirm at the call sites.
template <typename T, int P, int Q, int BY>
static __device__ __inline__ void read_A_notrans_g2r_1D_nosync(const int tx, const int ty, const T *dA, T *sA, T rA[Q]) {
  const int num_entries = P * Q;        // number of entries of A
  const int pass_size   = P * BY;       // entries copied per pass
  const int flat_id     = tx + P * ty;  // this thread's slot within a pass
  int       base;

#pragma unroll
  for (base = 0; base < num_entries - pass_size; base += pass_size) {
    sA[base + flat_id] = dA[base + flat_id];
  }
  // Final pass may be partial, so the index is guarded
  const int last = base + flat_id;
  if (last < num_entries) {
    sA[last] = dA[last];
  }
  // All copies into sA must have landed before any thread reads it back
  __syncthreads();

#pragma unroll
  for (int col = 0; col < Q; col++) {
    rA[col] = sA[tx + col * P];
  }
}

////////////////////////////////////////////////////////////////////////////////
// read A (trans) from global to reg.
// A is (P x Q)
// 2D thread config.
// with (P x BY) threads
// no sync at the end of the function
//
// Phase 1: thread (tx, ty), whose flat index is ty * P + tx, copies entries of dA into the same positions of sA in passes of
//          P * BY entries; the last pass is bounds-checked so nothing at or beyond index P * Q is touched.
// Phase 2: after the __syncthreads() barrier, the thread gathers the Q consecutive entries sA[tx * Q + j], j = 0, ..., Q - 1,
//          into rA[j].
// The only barrier sits between the two phases; nothing synchronizes after rA has been filled.
// NOTE(review): the d/s/r prefixes presumably denote global/shared/register storage -- confirm at the call sites.
template <typename T, int P, int Q, int BY>
static __device__ __inline__ void read_A_trans_g2r_1D_nosync(const int tx, const int ty, const T *dA, T *sA, T rA[Q]) {
  const int num_entries = P * Q;        // number of entries of A
  const int pass_size   = P * BY;       // entries copied per pass
  const int flat_id     = tx + P * ty;  // this thread's slot within a pass
  int       base;

#pragma unroll
  for (base = 0; base < num_entries - pass_size; base += pass_size) {
    sA[base + flat_id] = dA[base + flat_id];
  }
  // Final pass may be partial, so the index is guarded
  const int last = base + flat_id;
  if (last < num_entries) {
    sA[last] = dA[last];
  }
  // All copies into sA must have landed before any thread reads it back
  __syncthreads();

  // Each thread reads a contiguous run of Q entries starting at tx * Q
  const int first = tx * Q;
#pragma unroll
  for (int k = 0; k < Q; k++) {
    rA[k] = sA[first + k];
  }
}

////////////////////////////////////////////////////////////////////////////////
// read B from global to shared
// B is (Q x NB)
// 1D thread config.
// with (P x 1) threads
// no sync at the end of the function
//
// Thread tx copies entries tx, tx + P, tx + 2P, ... of dB into the same positions of sB, stopping before index Q * n.
// When n == NB the loop bound is a compile-time constant and the loop is unrolled; otherwise a runtime-bounded loop is used.
// This function issues no barrier of its own.
// NOTE(review): n is presumably the number of columns of B actually present (n <= NB) -- confirm against the callers.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void read_B_g2s_1D_nosync(const int tx, const int n, const T *dB, T *sB) {
  int base;

  if (n == NB) {
    // Full set of columns: bounds known at compile time
#pragma unroll
    for (base = 0; base < Q * NB - P; base += P) {
      sB[base + tx] = dB[base + tx];
    }
  } else {
    // Column count only known at run time
    const int limit = Q * n - P;
    for (base = 0; base < limit; base += P) {
      sB[base + tx] = dB[base + tx];
    }
  }
  // Final pass may be partial, so the index is guarded
  const int last = base + tx;
  if (last < Q * n) {
    sB[last] = dB[last];
  }
}

////////////////////////////////////////////////////////////////////////////////
// write C from reg. to global
// C is (P x NB)
// 1D thread config.
// with (P x 1) threads
// no sync at the end of the function
//
// Thread tx stores rC[i] to dC[i * P + tx] for each of the first n columns, i.e. positions tx, tx + P, tx + 2P, ... of dC.
// When n == NB the column loop has a compile-time trip count and is unrolled.
// NOTE(review): n is presumably the number of valid columns (n <= NB) -- confirm against the callers.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void write_C_r2g_1D_nosync(const int tx, const int n, T rC[NB], T *dC) {
  if (n == NB) {
    // Full set of columns: unrollable
#pragma unroll
    for (int col = 0; col < NB; col++) {
      dC[tx + col * P] = rC[col];
    }
    return;
  }
  // Partial set of columns: runtime trip count
  for (int col = 0; col < n; col++) {
    dC[tx + col * P] = rC[col];
  }
}

////////////////////////////////////////////////////////////////////////////////
// multiply C = A x B using 1D threads in P x 1 config
// A (P x Q) in reg., one row per thread
// B (Q x NB) in shared memory
// C in registers -- one row per thread
// no sync at the end of the function
//
// For every column i of B (read from sB[i * Q], ..., sB[i * Q + Q - 1]) the thread overwrites rC[i] with the dot product of
// its rA with that column. Any previous contents of rC are discarded.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void mul_rAsBrC_1D_nosync(T rA[Q], T *sB, T rC[NB]) {
  T rB[Q];

#pragma unroll
  for (int col = 0; col < NB; col++) {
    // Stage column `col` of B in a thread-local array before using it
    const T *b_col = sB + col * Q;
#pragma unroll
    for (int k = 0; k < Q; k++) {
      rB[k] = b_col[k];
    }
    // Accumulate from zero, in increasing k
    T acc = 0.0;
#pragma unroll
    for (int k = 0; k < Q; k++) {
      acc += rA[k] * rB[k];
    }
    rC[col] = acc;
  }
}

////////////////////////////////////////////////////////////////////////////////
// multiply C += A x B using 1D threads in P x 1 config
// A (P x Q) in reg., one row per thread
// B (Q x NB) in shared memory
// C in registers -- one row per thread
// no sync at the end of the function
//
// Same product as mul_rAsBrC_1D_nosync, except that the dot product with column i of B is added to the value already held
// in rC[i] instead of replacing it.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void addmul_rAsBrC_1D_nosync(T rA[Q], T *sB, T rC[NB]) {
  T rB[Q];

#pragma unroll
  for (int col = 0; col < NB; col++) {
    // Stage column `col` of B in a thread-local array before using it
    const T *b_col = sB + col * Q;
#pragma unroll
    for (int k = 0; k < Q; k++) {
      rB[k] = b_col[k];
    }
    // Continue accumulating on top of the incoming rC[col], in increasing k
    T acc = rC[col];
#pragma unroll
    for (int k = 0; k < Q; k++) {
      acc += rA[k] * rB[k];
    }
    rC[col] = acc;
  }
}

#endif  // CEED_MAGMA_COMMON_NONTENSOR_H