// Copyright (c) 2017-2022, Lawrence Livermore National Security, LLC and other CEED contributors.
// All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
//
// SPDX-License-Identifier: BSD-2-Clause
//
// This file is part of CEED: http://github.com/ceed

/// @file
/// Internal header for MAGMA backend common non-tensor basis definitions
#ifndef CEED_MAGMA_COMMON_NONTENSOR_H
#define CEED_MAGMA_COMMON_NONTENSOR_H

#include "magma-common-defs.h"

////////////////////////////////////////////////////////////////////////////////
// Load A (no-trans) from global memory into per-thread registers.
//
// A is (P x Q). Step 1: the whole thread block (P x BY threads, linear id
// ty * P + tx) stages all P * Q entries of dA into the shared buffer sA.
// Step 2: thread `tx` gathers sA[tx + k * P] for k = 0 .. Q-1 into rA. That is
// row `tx` of A when sA is read as column-major with leading dimension P.
//
// Requirements:
//   - sA must hold at least P * Q entries.
//   - The function contains a __syncthreads(), so every thread of the block
//     must call it.
// No barrier is issued after the register gather. The caller must synchronize
// before any thread overwrites sA.
template <typename T, int P, int Q, int BY>
static __device__ __inline__ void read_A_notrans_g2r_1D_nosync(const int tx, const int ty, const T *dA, T *sA, T rA[Q]) {
  const int n_total = P * Q;        // entries in A
  const int n_pass  = P * BY;       // entries moved per pass by the whole block
  const int flat    = tx + P * ty;  // linear thread id within the block

  // Full passes: while a whole pass fits, every thread copies without a check.
  int base = 0;
#pragma unroll
  for (; base < n_total - n_pass; base += n_pass) sA[base + flat] = dA[base + flat];

  // Final pass: at most n_pass entries remain, so guard the tail.
  if (base + flat < n_total) sA[base + flat] = dA[base + flat];

  // All of sA must be written before any thread gathers from it.
  __syncthreads();

#pragma unroll
  for (int k = 0; k < Q; ++k) rA[k] = sA[tx + k * P];
}
////////////////////////////////////////////////////////////////////////////////
// Load A (trans) from global memory into per-thread registers.
//
// Step 1: the block (P x BY threads, linear id ty * P + tx) stages all P * Q
// entries of dA into the shared buffer sA.
// Step 2: thread `tx` gathers the Q contiguous entries
// sA[tx * Q .. tx * Q + Q-1] into rA. The staged buffer is therefore read with
// leading dimension Q, transposed relative to read_A_notrans_g2r_1D_nosync.
// rA still receives one length-Q row of the (P x Q) operand.
//
// Requirements:
//   - sA must hold at least P * Q entries.
//   - The function contains a __syncthreads(), so every thread of the block
//     must call it.
// No barrier is issued after the register gather. The caller must synchronize
// before any thread overwrites sA.
template <typename T, int P, int Q, int BY>
static __device__ __inline__ void read_A_trans_g2r_1D_nosync(const int tx, const int ty, const T *dA, T *sA, T rA[Q]) {
  const int n_total = P * Q;        // entries in A
  const int n_pass  = P * BY;       // entries moved per pass by the whole block
  const int flat    = tx + P * ty;  // linear thread id within the block

  // Full passes without bounds checks, then one guarded tail pass.
  int base = 0;
#pragma unroll
  for (; base < n_total - n_pass; base += n_pass) sA[base + flat] = dA[base + flat];
  if (base + flat < n_total) sA[base + flat] = dA[base + flat];

  // All of sA must be written before any thread gathers from it.
  __syncthreads();

  const T *mine = sA + tx * Q;  // this thread's Q consecutive entries
#pragma unroll
  for (int k = 0; k < Q; ++k) rA[k] = mine[k];
}

////////////////////////////////////////////////////////////////////////////////
// Load B from global memory into shared memory.
//
// B is (Q x NB). Only the first Q * n contiguous entries of dB are copied to sB
// (n columns when B is read with leading dimension Q), and sB must hold at
// least that many. The thread config is 1D with (P x 1) threads: thread `tx`
// copies entries tx, tx + P, tx + 2P, ...
// When n == NB the trip count is a compile-time constant, so that copy loop is
// unrolled.
// No synchronization is performed. The caller must __syncthreads() before
// other threads read sB.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void read_B_g2s_1D_nosync(const int tx, const int n, const T *dB, T *sB) {
  int base = 0;

  if (n == NB) {
#pragma unroll
    for (; base < Q * NB - P; base += P) sB[base + tx] = dB[base + tx];
  } else {
    for (; base < Q * n - P; base += P) sB[base + tx] = dB[base + tx];
  }

  // Tail: at most P entries remain, so guard against running past Q * n.
  if (base + tx < Q * n) sB[base + tx] = dB[base + tx];
}
////////////////////////////////////////////////////////////////////////////////
// Store C from per-thread registers to global memory.
//
// C is (P x NB). The thread config is 1D with (P x 1) threads. Thread `tx`
// writes its first n register values to dC[tx + k * P] for k = 0 .. n-1
// (column-major, leading dimension P).
// NOTE(review): this assumes n <= NB, because rC holds NB values and the
// n != NB loop is unchecked.
// No synchronization is performed.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void write_C_r2g_1D_nosync(const int tx, const int n, T rC[NB], T *dC) {
  if (n == NB) {
    // Full tile: the trip count is a compile-time constant, so fully unroll.
#pragma unroll
    for (int k = 0; k < NB; ++k) dC[tx + k * P] = rC[k];
  } else {
    for (int k = 0; k < n; ++k) dC[tx + k * P] = rC[k];
  }
}

////////////////////////////////////////////////////////////////////////////////
// Accumulate C += A x B using 1D threads in a P x 1 config.
//   A (P x Q): in registers. The calling thread's row is in rA.
//   B (Q x NB): in shared memory. Column k is sB[k * Q .. k * Q + Q-1].
//   C: in registers. The calling thread's row is in rC.
// All NB columns are processed. No synchronization is performed, so sB must
// already be populated and visible to this thread.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void addmul_rAsBrC_1D_nosync(T rA[Q], T *sB, T rC[NB]) {
  T rB[Q];

#pragma unroll
  for (int k = 0; k < NB; ++k) {
    // Stage column k of B in registers before the dot product.
    const T *col = sB + k * Q;
#pragma unroll
    for (int q = 0; q < Q; ++q) rB[q] = col[q];

#pragma unroll
    for (int q = 0; q < Q; ++q) rC[k] += rA[q] * rB[q];
  }
}

////////////////////////////////////////////////////////////////////////////////
// Multiply C = A x B using 1D threads in a P x 1 config. The operand layouts
// are the same as in addmul_rAsBrC_1D_nosync. rC is cleared first and then
// accumulated into, so its previous contents are discarded.
// No synchronization is performed.
template <typename T, int P, int Q, int NB>
static __device__ __inline__ void mul_rAsBrC_1D_nosync(T rA[Q], T *sB, T rC[NB]) {
#pragma unroll
  for (int k = 0; k < NB; ++k) rC[k] = 0.0;
  addmul_rAsBrC_1D_nosync<T, P, Q, NB>(rA, sB, rC);
}

#endif  // CEED_MAGMA_COMMON_NONTENSOR_H