xref: /libCEED/include/ceed/jit-source/magma/magma-basis-weight-1d.h (revision d83cf49fece5d7d5441d5b92eb712b904329a4d2)
1 // Copyright (c) 2017-2024, Lawrence Livermore National Security, LLC and other CEED contributors.
2 // All Rights Reserved. See the top-level LICENSE and NOTICE files for details.
3 //
4 // SPDX-License-Identifier: BSD-2-Clause
5 //
6 // This file is part of CEED:  http://github.com/ceed
7 
8 /// @file
9 /// Internal header for MAGMA tensor basis weight in 1D
10 
11 #include "magma-common-tensor.h"
12 
////////////////////////////////////////////////////////////////////////////////
// weight basis action -- 1D
//
// Copies the Q quadrature weights from sTweight into this element's slot sV,
// one entry per thread.
//
// Assumptions
// 1. threadIdx.x-style 1D indexing: thread `tx` handles quadrature point `tx`;
//    threads with tx >= Q (if the launch has more than Q) do nothing
// 2. The output sV is in shared memory -- size Q
template <typename T, int Q>
static __device__ __inline__ void magma_weight_1d_device(const T *sTweight, T *sV, const int tx) {
  // Threads past the last quadrature point have nothing to copy.
  if (tx >= Q) return;

  sV[tx] = sTweight[tx];
}
24 
////////////////////////////////////////////////////////////////////////////////
// Fill dV with the 1D quadrature weights, once per element.
//
// Launch layout (as implied by the indexing below):
//   - threadIdx.x indexes quadrature points; threads with tx >= BASIS_Q idle
//   - threadIdx.y indexes elements within a block; block b covers elements
//     b * blockDim.y .. b * blockDim.y + blockDim.y - 1 (rows past nelem idle)
// Shared memory (obtained via MAGMA_DEVICE_SHARED) must hold
//   BASIS_Q + blockDim.y * BASIS_Q CeedScalars: one shared copy of the weights
//   followed by one BASIS_Q-sized slot per element row.
//
// dqweight1d: the BASIS_Q quadrature weights
// dV:         output; element e's weights are written to dV[e * v_stride + q]
// v_stride:   distance, in CeedScalars, between consecutive elements in dV
// nelem:      total number of elements
//
// Rows whose element index is >= nelem do not return early: every thread of
// the block must reach both __syncthreads() calls, so idle rows simply skip
// the per-element copy and the global write.
extern "C" __launch_bounds__(MAGMA_BASIS_BOUNDS(BASIS_Q, MAGMA_MAXTHREADS_1D)) __global__
    void magma_weight_1d_kernel(const CeedScalar *dqweight1d, CeedScalar *dV, const int v_stride, const int nelem) {
  MAGMA_DEVICE_SHARED(CeedScalar, shared_data)

  const int  tx        = threadIdx.x;
  const int  ty        = threadIdx.y;
  const int  elem_id   = (blockIdx.x * blockDim.y) + ty;
  const bool is_active = (elem_id < nelem);  // does row ty map to a real element?

  // shared memory pointers
  CeedScalar *sTweight = (CeedScalar *)shared_data;
  CeedScalar *sV       = sTweight + BASIS_Q;
  sV += ty * BASIS_Q;

  // read dqweight_1d (row 0 loads the shared copy for the whole block)
  if (ty == 0 && tx < BASIS_Q) {
    sTweight[tx] = dqweight1d[tx];
  }

  // Barriers are reached unconditionally by all threads of the block.
  __syncthreads();
  if (is_active) {
    magma_weight_1d_device<CeedScalar, BASIS_Q>(sTweight, sV, tx);
  }
  __syncthreads();

  // write V -- only real elements, and only the BASIS_Q valid entries
  if (is_active && tx < BASIS_Q) {
    CeedScalar *dV_elem = dV + elem_id * v_stride;
    dV_elem[tx]         = sV[tx];
  }
}
56